From 5797d49d8adab09eb98019b024af8711d8c31c90 Mon Sep 17 00:00:00 2001 From: dekun Date: Thu, 2 Jul 2026 16:23:09 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=B0=86=E5=85=B1=E7=94=A8?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E8=BF=81=E5=85=A5=20lib/=20=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=E5=8C=96=E7=9B=AE=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 统一 strategy、key_monitor、trade、hub 等共用库到 lib/ 子包,并补充 lib-structure 文档,便于四所与中控维护。 Co-authored-by: Cursor --- README.md | 7 +- crypto_monitor_binance/app.py | 121 +- crypto_monitor_gate/app.py | 129 +- crypto_monitor_gate_bot/app.py | 129 +- crypto_monitor_okx/app.py | 125 +- docs/lib-structure.md | 147 + lib/__init__.py | 1 + lib/ai/__init__.py | 1 + ai_client.py => lib/ai/ai_client.py | 0 ai_review_lib.py => lib/ai/ai_review_lib.py | 360 +- lib/common/__init__.py | 1 + .../common/auto_transfer_daily_lib.py | 0 .../common/form_submit_lib.py | 0 .../common/history_window_lib.py | 0 .../common/static}/account_risk_badge.css | 0 .../common/static}/account_risk_badge.js | 0 .../common/static}/ai_review_render.js | 0 .../common/static}/focus_chart_page.css | 0 .../common/static}/focus_chart_page.js | 0 .../common/static}/form_submit_guard.js | 0 .../common/static}/instance_embed.js | 0 .../common/static}/instance_page.css | 0 .../common/static}/instance_records_mobile.js | 0 .../common/static}/instance_theme.css | 0 .../common/static}/instance_theme.js | 0 .../common/static}/instance_theme_early.css | 0 {static => lib/common/static}/instance_ui.js | 0 .../common/static}/key_monitor_form.js | 0 .../common/static}/manual_order_rr_preview.js | 0 .../common/static}/strategy_roll.js | 0 .../common/static}/time_close_ui.js | 0 .../common/static}/trade_stats_calendar.css | 0 .../common/static}/trade_stats_calendar.js | 0 .../common/wechat_notify_lib.py | 0 lib/exchange/__init__.py | 1 + .../exchange/gate_position_history_lib.py | 0 .../exchange/gate_transfer_lib.py | 0 .../exchange/okx_orders_lib.py | 0 lib/hub/__init__.py | 1 + hub_auth.py => lib/hub/hub_auth.py | 72 +- hub_bridge.py => lib/hub/hub_bridge.py | 2052 +++--- .../hub/hub_calculator_lib.py | 996 +-- .../hub/hub_calculator_market_lib.py | 514 +- .../hub/hub_entry_plan_lib.py | 0 .../hub/hub_fund_history_lib.py | 814 +-- .../hub/hub_host_status_lib.py | 0 .../hub/hub_kline_store.py | 1762 +++--- .../hub/hub_macro_calendar_lib.py | 622 +- .../hub/hub_market_info_lib.py | 162 +- hub_ohlcv_lib.py => lib/hub/hub_ohlcv_lib.py | 0 .../hub/hub_position_metrics.py | 0 hub_sso.py => lib/hub/hub_sso.py | 0 .../hub/hub_symbol_archive_lib.py | 3352 +++++----- .../hub/hub_trades_lib.py | 1276 ++-- .../hub/hub_volume_rank_lib.py | 1190 ++-- lib/instance/__init__.py | 1 + .../instance/focus_chart_lib.py | 374 +- .../instance/instance_embed_context_lib.py | 168 +- .../instance/instance_embed_lib.py | 295 +- .../instance/instance_nav_lib.py | 0 .../instance/journal_chart_lib.py | 0 .../templates}/embed_boot_scripts.html | 0 .../templates}/embed_page_fragment.html | 0 .../instance/templates}/embed_shell.html | 0 lib/key_monitor/__init__.py | 1 + .../false_breakout_key_monitor_lib.py | 290 +- .../key_monitor/fib_key_monitor_lib.py | 280 +- .../key_monitor_full_margin_lib.py | 122 +- .../key_monitor/key_monitor_lib.py | 780 +-- .../key_monitor/key_monitor_schema_lib.py | 0 .../key_monitor/key_sl_tp_lib.py | 0 .../trigger_entry_key_monitor_lib.py | 592 +- lib/paths.py | 22 + lib/strategy/__init__.py | 1 + .../strategy/strategy_config.py | 0 strategy_db.py => lib/strategy/strategy_db.py | 328 +- .../strategy/strategy_exchange_base.py | 0 .../strategy/strategy_exchange_binance.py | 8 +- .../strategy/strategy_exchange_gate.py | 18 +- .../strategy/strategy_exchange_okx.py | 8 +- .../strategy/strategy_records_register.py | 144 +- .../strategy/strategy_register.py | 1240 ++-- .../strategy/strategy_roll_lib.py | 770 +-- .../strategy/strategy_roll_monitor_lib.py | 1040 +-- .../strategy/strategy_roll_ui_lib.py | 0 .../strategy/strategy_snapshot_lib.py | 1058 ++-- .../strategy/strategy_trade_labels.py | 0 .../strategy/strategy_trend_exchange.py | 0 .../strategy/strategy_trend_lib.py | 1390 ++--- .../strategy/strategy_trend_register.py | 3828 ++++++------ strategy_ui.py => lib/strategy/strategy_ui.py | 288 +- .../strategy/strategy_wechat_notify.py | 384 +- .../templates}/gate_transfer_block.html | 0 .../strategy/templates}/key_focus_v2.html | 0 .../templates}/key_monitor_panel.html | 0 .../templates}/key_monitor_rule_tips.html | 0 .../strategy/templates}/order_focus_v2.html | 0 .../order_monitor_rule_tips_binance.html | 0 .../order_monitor_rule_tips_gate.html | 0 .../order_monitor_rule_tips_gate_bot.html | 0 .../order_monitor_rule_tips_okx.html | 0 .../templates}/order_plan_preview_bar.html | 0 .../templates}/strategy_records_page.html | 0 .../strategy/templates}/strategy_roll.html | 0 .../templates}/strategy_roll_docs.html | 0 .../templates}/strategy_roll_panel.html | 0 .../strategy/templates}/strategy_subnav.html | 0 .../templates}/strategy_trading_page.html | 0 .../templates}/strategy_trend_disabled.html | 0 .../strategy_trend_disabled_panel.html | 0 .../templates}/strategy_trend_panel.html | 0 lib/trade/__init__.py | 1 + .../trade/account_risk_lib.py | 1690 ++--- .../trade/daily_open_limit_lib.py | 0 .../trade/manual_sltp_lib.py | 0 .../trade/order_monitor_display_lib.py | 0 .../trade/position_sizing_lib.py | 0 .../trade/time_close_lib.py | 0 .../trade/trade_exchange_stats_lib.py | 0 .../trade/trade_result_lib.py | 0 .../trade/trade_stats_calendar_lib.py | 0 manual_trading_hub/agent.py | 1816 +++--- manual_trading_hub/exchange_orders.py | 1724 ++--- manual_trading_hub/hub.py | 5552 ++++++++--------- manual_trading_hub/hub_ai/archive_quote.py | 322 +- manual_trading_hub/hub_ai/chat.py | 550 +- manual_trading_hub/hub_ai/client.py | 84 +- manual_trading_hub/hub_ai/context.py | 2140 +++---- manual_trading_hub/hub_ai/fund_history.py | 36 +- manual_trading_hub/hub_ai/routes.py | 400 +- manual_trading_hub/hub_ai/supervisor.py | 250 +- manual_trading_hub/hub_dashboard.py | 214 +- manual_trading_hub/hub_supervisor_lib.py | 1514 ++--- requirements.txt | 1 + scripts/backfill_trend_strategy_snapshots.py | 496 +- scripts/backfill_trend_trade_records.py | 376 +- scripts/build_embed_fragment.py | 66 +- scripts/clear_hub_kline_db.py | 186 +- scripts/dedupe_strategy_snapshots.py | 134 +- scripts/extract_instance_page_assets.py | 98 +- scripts/fix_trend_handoff_monitor_type.py | 156 +- scripts/migrate_to_lib.py | 252 + scripts/patch_position_sizing_to_exchanges.py | 394 +- tests/test_account_risk_lib.py | 1052 ++-- tests/test_ai_review_lib.py | 126 +- tests/test_archive_calendar.py | 120 +- tests/test_daily_open_limit_lib.py | 180 +- tests/test_false_breakout_key_monitor_lib.py | 152 +- tests/test_gate_position_history_lib.py | 52 +- tests/test_gate_transfer_lib.py | 88 +- tests/test_hub_agent_mark_price.py | 188 +- tests/test_hub_calculator_lib.py | 324 +- tests/test_hub_calculator_market_lib.py | 226 +- tests/test_hub_entry_plan_lib.py | 314 +- tests/test_hub_fund_history_lib.py | 226 +- tests/test_hub_host_status_lib.py | 116 +- tests/test_hub_kline_store.py | 932 +-- tests/test_hub_macro_calendar_lib.py | 146 +- tests/test_hub_monitor_payload.py | 78 +- tests/test_hub_ohlcv_lib.py | 445 +- tests/test_hub_supervisor_lib.py | 7 +- tests/test_hub_symbol_archive_lib.py | 696 +-- tests/test_hub_trades_archive_merge.py | 204 +- tests/test_hub_trades_lib.py | 396 +- tests/test_hub_trades_review_fields.py | 230 +- tests/test_hub_volume_rank_lib.py | 368 +- tests/test_instance_embed_context_lib.py | 56 +- tests/test_instance_embed_lib.py | 52 +- tests/test_instance_nav_lib.py | 42 +- tests/test_key_monitor_box_invalidate.py | 68 +- tests/test_key_monitor_rs_alert.py | 172 +- tests/test_key_monitor_rs_type.py | 54 +- tests/test_manual_sltp_lib.py | 64 +- tests/test_order_monitor_display_lib.py | 204 +- tests/test_position_limit_count.py | 156 +- tests/test_position_sizing_risk_display.py | 68 +- tests/test_strategy_roll_lib.py | 224 +- tests/test_strategy_roll_ui_lib.py | 88 +- tests/test_strategy_snapshot_dedup.py | 366 +- tests/test_trade_exchange_stats_lib.py | 96 +- tests/test_trade_result_lib.py | 60 +- tests/test_trade_stats_calendar_lib.py | 180 +- tests/test_trend_dca_enrich_fills.py | 202 +- tests/test_trend_dca_pnl.py | 86 +- tests/test_trend_finalize_trade_record.py | 184 +- tests/test_trend_hub_enrich.py | 88 +- tests/test_trend_hub_enrich_unified.py | 184 +- tests/test_trend_market_add_params.py | 82 +- tests/test_trend_preview_tp.py | 116 +- tests/test_trigger_entry_key_monitor_lib.py | 170 +- 190 files changed, 27946 insertions(+), 27499 deletions(-) create mode 100644 docs/lib-structure.md create mode 100644 lib/__init__.py create mode 100644 lib/ai/__init__.py rename ai_client.py => lib/ai/ai_client.py (100%) rename ai_review_lib.py => lib/ai/ai_review_lib.py (96%) create mode 100644 lib/common/__init__.py rename auto_transfer_daily_lib.py => lib/common/auto_transfer_daily_lib.py (100%) rename form_submit_lib.py => lib/common/form_submit_lib.py (100%) rename history_window_lib.py => lib/common/history_window_lib.py (100%) rename {static => lib/common/static}/account_risk_badge.css (100%) rename {static => lib/common/static}/account_risk_badge.js (100%) rename {static => lib/common/static}/ai_review_render.js (100%) rename {static => lib/common/static}/focus_chart_page.css (100%) rename {static => lib/common/static}/focus_chart_page.js (100%) rename {static => lib/common/static}/form_submit_guard.js (100%) rename {static => lib/common/static}/instance_embed.js (100%) rename {static => lib/common/static}/instance_page.css (100%) rename {static => lib/common/static}/instance_records_mobile.js (100%) rename {static => lib/common/static}/instance_theme.css (100%) rename {static => lib/common/static}/instance_theme.js (100%) rename {static => lib/common/static}/instance_theme_early.css (100%) rename {static => lib/common/static}/instance_ui.js (100%) rename {static => lib/common/static}/key_monitor_form.js (100%) rename {static => lib/common/static}/manual_order_rr_preview.js (100%) rename {static => lib/common/static}/strategy_roll.js (100%) rename {static => lib/common/static}/time_close_ui.js (100%) rename {static => lib/common/static}/trade_stats_calendar.css (100%) rename {static => lib/common/static}/trade_stats_calendar.js (100%) rename wechat_notify_lib.py => lib/common/wechat_notify_lib.py (100%) create mode 100644 lib/exchange/__init__.py rename gate_position_history_lib.py => lib/exchange/gate_position_history_lib.py (100%) rename gate_transfer_lib.py => lib/exchange/gate_transfer_lib.py (100%) rename okx_orders_lib.py => lib/exchange/okx_orders_lib.py (100%) create mode 100644 lib/hub/__init__.py rename hub_auth.py => lib/hub/hub_auth.py (92%) rename hub_bridge.py => lib/hub/hub_bridge.py (94%) rename hub_calculator_lib.py => lib/hub/hub_calculator_lib.py (96%) rename hub_calculator_market_lib.py => lib/hub/hub_calculator_market_lib.py (96%) rename hub_entry_plan_lib.py => lib/hub/hub_entry_plan_lib.py (100%) rename hub_fund_history_lib.py => lib/hub/hub_fund_history_lib.py (96%) rename hub_host_status_lib.py => lib/hub/hub_host_status_lib.py (100%) rename hub_kline_store.py => lib/hub/hub_kline_store.py (96%) rename hub_macro_calendar_lib.py => lib/hub/hub_macro_calendar_lib.py (96%) rename hub_market_info_lib.py => lib/hub/hub_market_info_lib.py (92%) rename hub_ohlcv_lib.py => lib/hub/hub_ohlcv_lib.py (100%) rename hub_position_metrics.py => lib/hub/hub_position_metrics.py (100%) rename hub_sso.py => lib/hub/hub_sso.py (100%) rename hub_symbol_archive_lib.py => lib/hub/hub_symbol_archive_lib.py (97%) rename hub_trades_lib.py => lib/hub/hub_trades_lib.py (96%) rename hub_volume_rank_lib.py => lib/hub/hub_volume_rank_lib.py (96%) create mode 100644 lib/instance/__init__.py rename focus_chart_lib.py => lib/instance/focus_chart_lib.py (95%) rename instance_embed_context_lib.py => lib/instance/instance_embed_context_lib.py (94%) rename instance_embed_lib.py => lib/instance/instance_embed_lib.py (95%) rename instance_nav_lib.py => lib/instance/instance_nav_lib.py (100%) rename journal_chart_lib.py => lib/instance/journal_chart_lib.py (100%) rename {embed_templates => lib/instance/templates}/embed_boot_scripts.html (100%) rename {embed_templates => lib/instance/templates}/embed_page_fragment.html (100%) rename {embed_templates => lib/instance/templates}/embed_shell.html (100%) create mode 100644 lib/key_monitor/__init__.py rename false_breakout_key_monitor_lib.py => lib/key_monitor/false_breakout_key_monitor_lib.py (95%) rename fib_key_monitor_lib.py => lib/key_monitor/fib_key_monitor_lib.py (96%) rename key_monitor_full_margin_lib.py => lib/key_monitor/key_monitor_full_margin_lib.py (83%) rename key_monitor_lib.py => lib/key_monitor/key_monitor_lib.py (95%) rename key_monitor_schema_lib.py => lib/key_monitor/key_monitor_schema_lib.py (100%) rename key_sl_tp_lib.py => lib/key_monitor/key_sl_tp_lib.py (100%) rename trigger_entry_key_monitor_lib.py => lib/key_monitor/trigger_entry_key_monitor_lib.py (95%) create mode 100644 lib/paths.py create mode 100644 lib/strategy/__init__.py rename strategy_config.py => lib/strategy/strategy_config.py (100%) rename strategy_db.py => lib/strategy/strategy_db.py (95%) rename strategy_exchange_base.py => lib/strategy/strategy_exchange_base.py (100%) rename strategy_exchange_binance.py => lib/strategy/strategy_exchange_binance.py (66%) rename strategy_exchange_gate.py => lib/strategy/strategy_exchange_gate.py (80%) rename strategy_exchange_okx.py => lib/strategy/strategy_exchange_okx.py (64%) rename strategy_records_register.py => lib/strategy/strategy_records_register.py (94%) rename strategy_register.py => lib/strategy/strategy_register.py (94%) rename strategy_roll_lib.py => lib/strategy/strategy_roll_lib.py (96%) rename strategy_roll_monitor_lib.py => lib/strategy/strategy_roll_monitor_lib.py (94%) rename strategy_roll_ui_lib.py => lib/strategy/strategy_roll_ui_lib.py (100%) rename strategy_snapshot_lib.py => lib/strategy/strategy_snapshot_lib.py (96%) rename strategy_trade_labels.py => lib/strategy/strategy_trade_labels.py (100%) rename strategy_trend_exchange.py => lib/strategy/strategy_trend_exchange.py (100%) rename strategy_trend_lib.py => lib/strategy/strategy_trend_lib.py (96%) rename strategy_trend_register.py => lib/strategy/strategy_trend_register.py (95%) rename strategy_ui.py => lib/strategy/strategy_ui.py (89%) rename strategy_wechat_notify.py => lib/strategy/strategy_wechat_notify.py (95%) rename {strategy_templates => lib/strategy/templates}/gate_transfer_block.html (100%) rename {strategy_templates => lib/strategy/templates}/key_focus_v2.html (100%) rename {strategy_templates => lib/strategy/templates}/key_monitor_panel.html (100%) rename {strategy_templates => lib/strategy/templates}/key_monitor_rule_tips.html (100%) rename {strategy_templates => lib/strategy/templates}/order_focus_v2.html (100%) rename {strategy_templates => lib/strategy/templates}/order_monitor_rule_tips_binance.html (100%) rename {strategy_templates => lib/strategy/templates}/order_monitor_rule_tips_gate.html (100%) rename {strategy_templates => lib/strategy/templates}/order_monitor_rule_tips_gate_bot.html (100%) rename {strategy_templates => lib/strategy/templates}/order_monitor_rule_tips_okx.html (100%) rename {strategy_templates => lib/strategy/templates}/order_plan_preview_bar.html (100%) rename {strategy_templates => lib/strategy/templates}/strategy_records_page.html (100%) rename {strategy_templates => lib/strategy/templates}/strategy_roll.html (100%) rename {strategy_templates => lib/strategy/templates}/strategy_roll_docs.html (100%) rename {strategy_templates => lib/strategy/templates}/strategy_roll_panel.html (100%) rename {strategy_templates => lib/strategy/templates}/strategy_subnav.html (100%) rename {strategy_templates => lib/strategy/templates}/strategy_trading_page.html (100%) rename {strategy_templates => lib/strategy/templates}/strategy_trend_disabled.html (100%) rename {strategy_templates => lib/strategy/templates}/strategy_trend_disabled_panel.html (100%) rename {strategy_templates => lib/strategy/templates}/strategy_trend_panel.html (100%) create mode 100644 lib/trade/__init__.py rename account_risk_lib.py => lib/trade/account_risk_lib.py (96%) rename daily_open_limit_lib.py => lib/trade/daily_open_limit_lib.py (100%) rename manual_sltp_lib.py => lib/trade/manual_sltp_lib.py (100%) rename order_monitor_display_lib.py => lib/trade/order_monitor_display_lib.py (100%) rename position_sizing_lib.py => lib/trade/position_sizing_lib.py (100%) rename time_close_lib.py => lib/trade/time_close_lib.py (100%) rename trade_exchange_stats_lib.py => lib/trade/trade_exchange_stats_lib.py (100%) rename trade_result_lib.py => lib/trade/trade_result_lib.py (100%) rename trade_stats_calendar_lib.py => lib/trade/trade_stats_calendar_lib.py (100%) create mode 100644 scripts/migrate_to_lib.py diff --git a/README.md b/README.md index 0a40a7e..1bc7685 100644 --- a/README.md +++ b/README.md @@ -53,8 +53,11 @@ bash deploy/setup_env.sh --install-system-deps | `crypto_monitor_gate_bot/` | Gate 机器人 / 趋势户 | [部署文档.md](./crypto_monitor_gate_bot/部署文档.md) | | `crypto_monitor_okx/` | OKX 永续 | [部署文档.md](./crypto_monitor_okx/部署文档.md) | | `manual_trading_hub/` | 中控 + 子代理 | [部署文档.md](./manual_trading_hub/部署文档.md) | -| 根目录 `strategy_*.py` | 策略共用库 | [策略交易说明.md](./策略交易说明.md) | -| 根目录 `key_*_lib.py` | 关键位 / 止盈止损共用库 | [关键位止盈止损与移动保本更新说明.md](./关键位止盈止损与移动保本更新说明.md) | +| `lib/` | **共用模块**(策略、关键位、交易、中控库、AI、静态与模板) | **[docs/lib-structure.md](./docs/lib-structure.md)** | +| `brand/` | 各所共用图标与 manifest | — | +| `docs/`、`deploy/`、`scripts/`、`tests/` | 文档、环境、脚本、单元测试 | — | + +共用代码 import 示例:`from lib.strategy.strategy_db import init_strategy_tables`(各所启动时仍将仓库根加入 `PYTHONPATH`)。详见 **[docs/lib-structure.md](./docs/lib-structure.md)**。 --- diff --git a/crypto_monitor_binance/app.py b/crypto_monitor_binance/app.py index a043fc4..8dcb5b4 100644 --- a/crypto_monitor_binance/app.py +++ b/crypto_monitor_binance/app.py @@ -34,14 +34,15 @@ import sys if _REPO_ROOT not in sys.path: sys.path.insert(0, _REPO_ROOT) -from ai_client import ai_generate, ai_review, ai_short_advice -from ai_review_lib import ( +from lib.paths import common_static_dir +from lib.ai.ai_client import ai_generate, ai_review, ai_short_advice +from lib.ai.ai_review_lib import ( build_journal_ai_chart_path, collect_images_for_ai_review, journal_row_lines_for_ai, ) -from form_submit_lib import check_duplicate_submit, submit_scope_add_key, submit_scope_add_order -from fib_key_monitor_lib import ( +from lib.common.form_submit_lib import check_duplicate_submit, submit_scope_add_key, submit_scope_add_order +from lib.key_monitor.fib_key_monitor_lib import ( FIB_KEY_MONITOR_TYPES, backfill_missing_key_signal_types, calc_fib_plan, @@ -52,7 +53,7 @@ from fib_key_monitor_lib import ( key_signal_type_for_trade_record, stored_key_signal_type, ) -from false_breakout_key_monitor_lib import ( +from lib.key_monitor.false_breakout_key_monitor_lib import ( FALSE_BREAKOUT_MONITOR_TYPE, FALSE_BREAKOUT_VALIDITY_HOURS, calc_false_breakout_plan, @@ -65,7 +66,7 @@ from false_breakout_key_monitor_lib import ( normalize_false_breakout_symbol, storage_bounds_from_key_price, ) -from strategy_trade_labels import ( +from lib.strategy.strategy_trade_labels import ( STRATEGY_ENTRY_REASON_OPTIONS, apply_order_monitor_source_labels, entry_reason_for_monitor_type, @@ -74,7 +75,7 @@ from strategy_trade_labels import ( trade_record_monitor_type as resolve_trade_record_monitor_type, trend_plan_id_from_monitor_row, ) -from journal_chart_lib import ( +from lib.instance.journal_chart_lib import ( JOURNAL_CHART_DEFAULT_LIMIT, JOURNAL_CHART_DEFAULT_TF1, JOURNAL_CHART_DEFAULT_TF2, @@ -90,7 +91,7 @@ from journal_chart_lib import ( trade_review_fetch_window, trim_rows_for_trade_review, ) -from key_sl_tp_lib import ( +from lib.key_monitor.key_sl_tp_lib import ( breakeven_enabled_from_row, normalize_sl_tp_mode, parse_breakeven_enabled_form, @@ -99,7 +100,7 @@ from key_sl_tp_lib import ( sl_tp_mode_label, sl_tp_plan_summary_text, ) -from time_close_lib import ( +from lib.trade.time_close_lib import ( TIME_CLOSE_RESULT, apply_time_close_to_payload, ensure_time_close_schema, @@ -110,13 +111,13 @@ from time_close_lib import ( time_close_label, time_close_settings_from_row, ) -from manual_sltp_lib import ( +from lib.trade.manual_sltp_lib import ( normalize_open_sltp_mode, resolve_entrust_sltp_prices, resolve_open_sltp_prices, ) -from key_monitor_schema_lib import ensure_key_monitor_schema -from trigger_entry_key_monitor_lib import ( +from lib.key_monitor.key_monitor_schema_lib import ensure_key_monitor_schema +from lib.key_monitor.trigger_entry_key_monitor_lib import ( BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, TRIGGER_ENTRY_CLOSE_EXCHANGE_FAILED, @@ -139,7 +140,7 @@ from trigger_entry_key_monitor_lib import ( validate_trigger_entry_geometry, validate_trigger_entry_rr, ) -from position_sizing_lib import ( +from lib.trade.position_sizing_lib import ( OPEN_SOURCE_KEY_AUTO, OPEN_SOURCE_KEY_TRIGGER, OPEN_SOURCE_MANUAL, @@ -155,12 +156,12 @@ from position_sizing_lib import ( mode_label_zh, risk_percent_for_storage, ) -from key_monitor_full_margin_lib import ( +from lib.key_monitor.key_monitor_full_margin_lib import ( monitor_type_disallowed_in_full_margin, purge_disallowed_key_monitors, ) -from auto_transfer_daily_lib import run_auto_transfer_once_per_day -from key_monitor_lib import ( +from lib.common.auto_transfer_daily_lib import run_auto_transfer_once_per_day +from lib.key_monitor.key_monitor_lib import ( KEY_DIRECTION_WATCH, KEY_MONITOR_ALERT_ONLY_TYPES, KEY_MONITOR_AUTO_TYPES, @@ -180,15 +181,15 @@ from key_monitor_lib import ( rs_break_from_direction, run_rs_level_alert_tick, ) -from order_monitor_display_lib import ( +from lib.trade.order_monitor_display_lib import ( apply_order_price_display_fields, enrich_order_display_fields, order_monitor_tpsl_needs_sync, ) -from wechat_notify_lib import build_wechat_rs_level_message, send_wechat_webhook -from hub_auth import request_allowed as hub_request_allowed -from hub_volume_rank_lib import resolve_daily_volume_rank -from history_window_lib import ( +from lib.common.wechat_notify_lib import build_wechat_rs_level_message, send_wechat_webhook +from lib.hub.hub_auth import request_allowed as hub_request_allowed +from lib.hub.hub_volume_rank_lib import resolve_daily_volume_rank +from lib.common.history_window_lib import ( PRESET_CUSTOM, PRESET_UTC_LAST24H, PRESET_UTC_LAST7D, @@ -201,8 +202,8 @@ from history_window_lib import ( utc_window_to_bj_sql_strings, utc_window_to_utc_sql_strings, ) -from trade_result_lib import count_winning_trades, normalize_result_with_pnl -from trade_exchange_stats_lib import ( +from lib.trade.trade_result_lib import count_winning_trades, normalize_result_with_pnl +from lib.trade.trade_exchange_stats_lib import ( attach_exchange_stats_to_trade, filter_position_lifecycle_fills, sum_binance_commission_income, @@ -353,7 +354,7 @@ ORDER_CHART_ENABLED = os.getenv("ORDER_CHART_ENABLED", "true").lower() == "true" ORDER_CHART_TFS = [x.strip() for x in (os.getenv("ORDER_CHART_TFS", "4h,1h,15m,5m") or "").split(",") if x.strip()] ORDER_CHART_LIMIT = int(os.getenv("ORDER_CHART_LIMIT", "100")) ORDER_CHART_DIR = resolve_path(os.getenv("ORDER_CHART_DIR", "static/images/order_charts")) -from daily_open_limit_lib import ( +from lib.trade.daily_open_limit_lib import ( build_daily_open_alert_prompt, can_trade_new_open, check_daily_open_hard_limit, @@ -1520,10 +1521,10 @@ def init_db(): close_reason TEXT, closed_at TEXT)""" ) - from strategy_db import init_strategy_tables + from lib.strategy.strategy_db import init_strategy_tables init_strategy_tables(conn) - from account_risk_lib import ensure_account_risk_schema + from lib.trade.account_risk_lib import ensure_account_risk_schema ensure_account_risk_schema(conn) backfill_missing_key_signal_types(conn, monitor_type=ORDER_MONITOR_TYPE_KEY_AUTO) @@ -1560,7 +1561,7 @@ def get_db(): def hub_account_risk_status(conn): - from account_risk_lib import ( + from lib.trade.account_risk_lib import ( apply_position_limit_risk, compute_account_risk_status, enrich_risk_status_countdown, @@ -1576,7 +1577,7 @@ def hub_account_risk_status(conn): fmt_local_ms=ms_to_app_local_str, ) st = enrich_risk_status_countdown(st, now=now, daily_reset_hour=TRADING_DAY_RESET_HOUR) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors return apply_position_limit_risk( st, @@ -1593,7 +1594,7 @@ def hub_user_initiated_close( trade_record_id=None, closed_at_ms=None, ): - from account_risk_lib import CLOSE_SOURCE_USER_HUB, on_user_initiated_close + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_HUB, on_user_initiated_close src = (source or "").strip() or CLOSE_SOURCE_USER_HUB on_user_initiated_close( @@ -2120,7 +2121,7 @@ def get_effective_trade_field(row, reviewed_key, base_key, default=None): def to_effective_trade_dict(row): item = row_to_dict(row) - from order_monitor_display_lib import snapshot_stop_loss + from lib.trade.order_monitor_display_lib import snapshot_stop_loss open_stop = snapshot_stop_loss(item.get("initial_stop_loss"), item.get("stop_loss")) item["display_open_stop_loss"] = open_stop @@ -2661,7 +2662,7 @@ def insert_trade_record( open_ts_ms = _to_ms_with_fallback(opened_at_ms, open_ts) close_ts_ms = _to_ms_with_fallback(closed_at_ms, close_ts) kst = key_signal_type_for_trade_record(key_signal_type, KEY_MONITOR_AUTO_TYPES) - from order_monitor_display_lib import snapshot_stop_loss + from lib.trade.order_monitor_display_lib import snapshot_stop_loss snap_sl = snapshot_stop_loss(initial_stop_loss, stop_loss) er = ( @@ -3193,7 +3194,7 @@ def resolve_capital_base_for_key_open(conn, trading_day, live_capital): def precheck_risk(conn, symbol, direction): now = app_now() - from account_risk_lib import account_risk_blocks_trading + from lib.trade.account_risk_lib import account_risk_blocks_trading ok_risk, risk_reason = account_risk_blocks_trading( conn, @@ -3205,7 +3206,7 @@ def precheck_risk(conn, symbol, direction): return False, risk_reason if not trading_day_reset_allows_new_open(now): return False, f"北京时间 {TRADING_DAY_RESET_HOUR}:00 前不允许持仓" - from account_risk_lib import position_limit_reached + from lib.trade.account_risk_lib import position_limit_reached reached, active_count, mx = position_limit_reached(conn, max_active_positions=MAX_ACTIVE_POSITIONS) if reached: @@ -3898,7 +3899,7 @@ def list_orphan_live_positions(conn): ex = normalize_exchange_symbol(r["exchange_symbol"] or r["symbol"]) active_keys.add((ex, (r["direction"] or "long").strip().lower())) - from hub_position_metrics import parse_position_entry_price + from lib.hub.hub_position_metrics import parse_position_entry_price orphans = [] for lp in live_rows: @@ -4097,7 +4098,7 @@ def parse_ccxt_position_metrics(position, order_leverage=None): cs = float(get_contract_size(sym)) if sym else 1.0 except Exception: cs = 1.0 - from hub_position_metrics import enrich_ccxt_position_metrics_out + from lib.hub.hub_position_metrics import enrich_ccxt_position_metrics_out enrich_ccxt_position_metrics_out( p, out, contract_size=cs, funds_decimals=FUNDS_DECIMALS @@ -6807,14 +6808,14 @@ def background_task(): check_trigger_entry_key_monitors() _roll_cfg = app.extensions.get("strategy_roll_cfg") if _roll_cfg: - from strategy_roll_monitor_lib import check_roll_monitors + from lib.strategy.strategy_roll_monitor_lib import check_roll_monitors check_roll_monitors(_roll_cfg) check_key_monitors() check_order_monitors() cfg = app.extensions.get("strategy_trend_cfg") if cfg: - from strategy_trend_register import check_trend_pullback_plans + from lib.strategy.strategy_trend_register import check_trend_pullback_plans check_trend_pullback_plans(cfg) except: @@ -7006,7 +7007,7 @@ def render_main_page(page="trade", embed_mode=None): conn = get_db() session_row = ensure_session(conn, trading_day) local_current_capital = float(session_row["current_capital"]) - from instance_embed_context_lib import ( + from lib.instance.instance_embed_context_lib import ( embed_render_plan, minimal_stats_bundle, trade_records_summary, @@ -7070,7 +7071,7 @@ def render_main_page(page="trade", embed_mode=None): records = [] total = miss_count = rate = occupied_miss_total = 0 active_count = len(order_list) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors position_limit_count = count_position_limit_active_monitors(conn) opens_today = count_opens_for_trading_day(conn, trading_day) @@ -7101,7 +7102,7 @@ def render_main_page(page="trade", embed_mode=None): ) strategy_extra = {} if plan.strategy: - from strategy_ui import strategy_render_extras + from lib.strategy.strategy_ui import strategy_render_extras strategy_extra = strategy_render_extras( conn, @@ -7114,7 +7115,7 @@ def render_main_page(page="trade", embed_mode=None): if plan.orphan_live and not order_list and exchange_private_api_configured(): orphan_live_positions = list_orphan_live_positions(conn) conn.close() - from instance_embed_lib import embed_context_extras + from lib.instance.instance_embed_lib import embed_context_extras template_ctx = dict( page=page, @@ -7236,7 +7237,7 @@ def api_account_snapshot(): funding_usdt = round(funding_capital, FUNDS_DECIMALS) if funding_capital is not None else None current_capital = round(trading_capital, FUNDS_DECIMALS) if trading_capital is not None else round(local_current_capital, FUNDS_DECIMALS) recommended_capital = get_recommended_capital(current_capital) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors position_limit_count = count_position_limit_active_monitors(conn) opens_today = count_opens_for_trading_day(conn, trading_day) @@ -7523,7 +7524,7 @@ def api_price_snapshot(): pass conn.close() - from hub_position_metrics import build_position_marks_list + from lib.hub.hub_position_metrics import build_position_marks_list position_marks = build_position_marks_list( all_swap_positions, @@ -7782,7 +7783,7 @@ def api_order_kline(): "volume": float(bar[5]), }) - from focus_chart_lib import ( + from lib.instance.focus_chart_lib import ( build_order_kline_order_payload, load_swap_positions_for_order_kline, metrics_for_order_item, @@ -7810,7 +7811,7 @@ def api_order_kline(): ex_metrics=ex_metrics, ) - from focus_chart_lib import kline_api_price_fields + from lib.instance.focus_chart_lib import kline_api_price_fields price_fields = kline_api_price_fields( exchange, @@ -7927,7 +7928,7 @@ def api_key_kline(): "lower_pct": lower_pct, } - from focus_chart_lib import enrich_key_kline_response + from lib.instance.focus_chart_lib import enrich_key_kline_response price_display, key_info = enrich_key_kline_response( symbol=symbol, @@ -7936,7 +7937,7 @@ def api_key_kline(): format_price_fn=format_price_for_symbol, ) - from focus_chart_lib import kline_api_price_fields + from lib.instance.focus_chart_lib import kline_api_price_fields price_fields = kline_api_price_fields( exchange, @@ -8859,7 +8860,7 @@ def del_order(id): opened_at=opened_at, closed_at=closed_at, ) - from account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close on_user_initiated_close( conn, @@ -8873,7 +8874,7 @@ def del_order(id): try: _rcfg = app.extensions.get("strategy_roll_cfg") if isinstance(_rcfg, dict): - from strategy_register import roll_sync_after_external_close + from lib.strategy.strategy_register import roll_sync_after_external_close roll_sync_after_external_close(_rcfg, conn, row["symbol"], row["direction"]) except Exception: @@ -8934,7 +8935,7 @@ def del_order(id): opened_at=opened_at, closed_at=closed_at, ) - from account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close on_user_initiated_close( conn, @@ -8948,7 +8949,7 @@ def del_order(id): try: _rcfg = app.extensions.get("strategy_roll_cfg") if isinstance(_rcfg, dict): - from strategy_register import roll_sync_after_external_close + from lib.strategy.strategy_register import roll_sync_after_external_close roll_sync_after_external_close(_rcfg, conn, row["symbol"], row["direction"]) except Exception: @@ -9109,7 +9110,7 @@ def add_journal(): d.get("post_breakeven_stare"), d.get("new_trade_while_occupied"), d.get("note"), image_filename ) ) - from account_risk_lib import on_journal_saved + from lib.trade.account_risk_lib import on_journal_saved on_journal_saved( conn, @@ -9197,7 +9198,7 @@ def api_reviews(): return jsonify([row_to_dict(r) for r in rows]) -_REPO_STATIC_DIR = os.path.join(os.path.dirname(BASE_DIR), "static") +_REPO_STATIC_DIR = common_static_dir(os.path.dirname(BASE_DIR)) _AI_REVIEW_RENDER_JS = os.path.join(_REPO_STATIC_DIR, "ai_review_render.js") _FORM_SUBMIT_GUARD_JS = os.path.join(_REPO_STATIC_DIR, "form_submit_guard.js") _MANUAL_ORDER_RR_PREVIEW_JS = os.path.join(_REPO_STATIC_DIR, "manual_order_rr_preview.js") @@ -9422,7 +9423,7 @@ def api_trade_record_review_update(): tuple(base_params + [rec_id]), ) if reviewed_result == "手动平仓" and reviewed_miss_reason: - from account_risk_lib import apply_manual_close_journal_cooloff + from lib.trade.account_risk_lib import apply_manual_close_journal_cooloff apply_manual_close_journal_cooloff( conn, @@ -9571,7 +9572,7 @@ def _hub_account_bundle(): def _hub_fetch_market(base=""): - from hub_market_info_lib import fetch_usdt_swap_market_info + from lib.hub.hub_market_info_lib import fetch_usdt_swap_market_info return fetch_usdt_swap_market_info( base_or_symbol=base, @@ -9584,7 +9585,7 @@ def _hub_fetch_market(base=""): def _hub_fetch_ohlcv(symbol, timeframe, since_ms=None, limit=500): - from hub_ohlcv_lib import fetch_ohlcv_for_hub + from lib.hub.hub_ohlcv_lib import fetch_ohlcv_for_hub return fetch_ohlcv_for_hub( symbol=symbol, @@ -9600,7 +9601,7 @@ def _hub_fetch_ohlcv(symbol, timeframe, since_ms=None, limit=500): def _hub_fetch_volume_rank(top_n=20): - from hub_volume_rank_lib import fetch_usdt_swap_volume_rank + from lib.hub.hub_volume_rank_lib import fetch_usdt_swap_volume_rank return fetch_usdt_swap_volume_rank( exchange=exchange, @@ -9617,7 +9618,7 @@ try: _repo_root = Path(__file__).resolve().parent.parent if str(_repo_root) not in sys.path: sys.path.insert(0, str(_repo_root)) - from hub_bridge import install_on_app + from lib.hub.hub_bridge import install_on_app install_on_app( app, @@ -9660,8 +9661,8 @@ def strategy_roll_page(): return redirect("/strategy") -from strategy_register import install_strategy_trading -from strategy_trend_register import install_strategy_trend +from lib.strategy.strategy_register import install_strategy_trading +from lib.strategy.strategy_trend_register import install_strategy_trend install_strategy_trading(app, _REPO_ROOT, app_module=sys.modules[__name__]) install_strategy_trend(app, _REPO_ROOT, app_module=sys.modules[__name__]) diff --git a/crypto_monitor_gate/app.py b/crypto_monitor_gate/app.py index 4bfe51d..6b84d96 100644 --- a/crypto_monitor_gate/app.py +++ b/crypto_monitor_gate/app.py @@ -34,14 +34,15 @@ import sys if _REPO_ROOT not in sys.path: sys.path.insert(0, _REPO_ROOT) -from ai_client import ai_generate, ai_review, ai_short_advice -from ai_review_lib import ( +from lib.paths import common_static_dir +from lib.ai.ai_client import ai_generate, ai_review, ai_short_advice +from lib.ai.ai_review_lib import ( build_journal_ai_chart_path, collect_images_for_ai_review, journal_row_lines_for_ai, ) -from form_submit_lib import check_duplicate_submit, submit_scope_add_key, submit_scope_add_order -from fib_key_monitor_lib import ( +from lib.common.form_submit_lib import check_duplicate_submit, submit_scope_add_key, submit_scope_add_order +from lib.key_monitor.fib_key_monitor_lib import ( FIB_KEY_MONITOR_TYPES, KEY_ENTRY_REASON_BY_SIGNAL, backfill_missing_key_signal_types, @@ -53,7 +54,7 @@ from fib_key_monitor_lib import ( key_signal_type_for_trade_record, stored_key_signal_type, ) -from false_breakout_key_monitor_lib import ( +from lib.key_monitor.false_breakout_key_monitor_lib import ( FALSE_BREAKOUT_MONITOR_TYPE, FALSE_BREAKOUT_VALIDITY_HOURS, calc_false_breakout_plan, @@ -66,7 +67,7 @@ from false_breakout_key_monitor_lib import ( normalize_false_breakout_symbol, storage_bounds_from_key_price, ) -from strategy_trade_labels import ( +from lib.strategy.strategy_trade_labels import ( STRATEGY_ENTRY_REASON_OPTIONS, apply_order_monitor_source_labels, entry_reason_for_monitor_type, @@ -75,7 +76,7 @@ from strategy_trade_labels import ( trade_record_monitor_type as resolve_trade_record_monitor_type, trend_plan_id_from_monitor_row, ) -from journal_chart_lib import ( +from lib.instance.journal_chart_lib import ( JOURNAL_CHART_DEFAULT_LIMIT, JOURNAL_CHART_DEFAULT_TF1, JOURNAL_CHART_DEFAULT_TF2, @@ -91,7 +92,7 @@ from journal_chart_lib import ( trade_review_fetch_window, trim_rows_for_trade_review, ) -from key_sl_tp_lib import ( +from lib.key_monitor.key_sl_tp_lib import ( breakeven_enabled_from_row, normalize_sl_tp_mode, parse_breakeven_enabled_form, @@ -100,7 +101,7 @@ from key_sl_tp_lib import ( sl_tp_mode_label, sl_tp_plan_summary_text, ) -from time_close_lib import ( +from lib.trade.time_close_lib import ( TIME_CLOSE_RESULT, apply_time_close_to_payload, ensure_time_close_schema, @@ -111,13 +112,13 @@ from time_close_lib import ( time_close_label, time_close_settings_from_row, ) -from manual_sltp_lib import ( +from lib.trade.manual_sltp_lib import ( normalize_open_sltp_mode, resolve_entrust_sltp_prices, resolve_open_sltp_prices, ) -from key_monitor_schema_lib import ensure_key_monitor_schema -from trigger_entry_key_monitor_lib import ( +from lib.key_monitor.key_monitor_schema_lib import ensure_key_monitor_schema +from lib.key_monitor.trigger_entry_key_monitor_lib import ( BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, TRIGGER_ENTRY_CLOSE_EXCHANGE_FAILED, @@ -140,7 +141,7 @@ from trigger_entry_key_monitor_lib import ( validate_trigger_entry_geometry, validate_trigger_entry_rr, ) -from position_sizing_lib import ( +from lib.trade.position_sizing_lib import ( OPEN_SOURCE_KEY_AUTO, OPEN_SOURCE_KEY_TRIGGER, OPEN_SOURCE_MANUAL, @@ -154,12 +155,12 @@ from position_sizing_lib import ( mode_label_zh, risk_percent_for_storage, ) -from key_monitor_full_margin_lib import ( +from lib.key_monitor.key_monitor_full_margin_lib import ( monitor_type_disallowed_in_full_margin, purge_disallowed_key_monitors, ) -from auto_transfer_daily_lib import run_auto_transfer_once_per_day -from key_monitor_lib import ( +from lib.common.auto_transfer_daily_lib import run_auto_transfer_once_per_day +from lib.key_monitor.key_monitor_lib import ( KEY_DIRECTION_WATCH, KEY_MONITOR_ALERT_ONLY_TYPES, KEY_MONITOR_AUTO_TYPES, @@ -179,16 +180,16 @@ from key_monitor_lib import ( rs_break_from_direction, run_rs_level_alert_tick, ) -from order_monitor_display_lib import ( +from lib.trade.order_monitor_display_lib import ( apply_order_price_display_fields, enrich_order_display_fields, order_monitor_tpsl_needs_sync, ) -from wechat_notify_lib import build_wechat_rs_level_message, send_wechat_webhook -from hub_auth import request_allowed as hub_request_allowed -from instance_nav_lib import request_is_hub_soft_nav -from hub_volume_rank_lib import resolve_daily_volume_rank -from history_window_lib import ( +from lib.common.wechat_notify_lib import build_wechat_rs_level_message, send_wechat_webhook +from lib.hub.hub_auth import request_allowed as hub_request_allowed +from lib.instance.instance_nav_lib import request_is_hub_soft_nav +from lib.hub.hub_volume_rank_lib import resolve_daily_volume_rank +from lib.common.history_window_lib import ( PRESET_CUSTOM, PRESET_UTC_LAST24H, PRESET_UTC_LAST7D, @@ -201,8 +202,8 @@ from history_window_lib import ( utc_window_to_bj_sql_strings, utc_window_to_utc_sql_strings, ) -from trade_result_lib import count_winning_trades, normalize_result_with_pnl -from trade_exchange_stats_lib import attach_exchange_stats_to_trade, filter_position_lifecycle_fills +from lib.trade.trade_result_lib import count_winning_trades, normalize_result_with_pnl +from lib.trade.trade_exchange_stats_lib import attach_exchange_stats_to_trade, filter_position_lifecycle_fills def load_env_file(path): @@ -343,7 +344,7 @@ ORDER_CHART_ENABLED = os.getenv("ORDER_CHART_ENABLED", "true").lower() == "true" ORDER_CHART_TFS = [x.strip() for x in (os.getenv("ORDER_CHART_TFS", "4h,1h,15m,5m") or "").split(",") if x.strip()] ORDER_CHART_LIMIT = int(os.getenv("ORDER_CHART_LIMIT", "100")) ORDER_CHART_DIR = resolve_path(os.getenv("ORDER_CHART_DIR", "static/images/order_charts")) -from daily_open_limit_lib import ( +from lib.trade.daily_open_limit_lib import ( build_daily_open_alert_prompt, can_trade_new_open, check_daily_open_hard_limit, @@ -1506,10 +1507,10 @@ def init_db(): close_reason TEXT, closed_at TEXT)""" ) - from strategy_db import init_strategy_tables + from lib.strategy.strategy_db import init_strategy_tables init_strategy_tables(conn) - from account_risk_lib import ensure_account_risk_schema + from lib.trade.account_risk_lib import ensure_account_risk_schema ensure_account_risk_schema(conn) backfill_missing_key_signal_types(conn, monitor_type=ORDER_MONITOR_TYPE_KEY_AUTO) @@ -1546,7 +1547,7 @@ def get_db(): def hub_account_risk_status(conn): - from account_risk_lib import ( + from lib.trade.account_risk_lib import ( apply_position_limit_risk, compute_account_risk_status, enrich_risk_status_countdown, @@ -1562,7 +1563,7 @@ def hub_account_risk_status(conn): fmt_local_ms=ms_to_app_local_str, ) st = enrich_risk_status_countdown(st, now=now, daily_reset_hour=TRADING_DAY_RESET_HOUR) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors return apply_position_limit_risk( st, @@ -1579,7 +1580,7 @@ def hub_user_initiated_close( trade_record_id=None, closed_at_ms=None, ): - from account_risk_lib import CLOSE_SOURCE_USER_HUB, on_user_initiated_close + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_HUB, on_user_initiated_close src = (source or "").strip() or CLOSE_SOURCE_USER_HUB on_user_initiated_close( @@ -2072,7 +2073,7 @@ def get_effective_trade_field(row, reviewed_key, base_key, default=None): def to_effective_trade_dict(row): item = row_to_dict(row) - from order_monitor_display_lib import snapshot_stop_loss + from lib.trade.order_monitor_display_lib import snapshot_stop_loss open_stop = snapshot_stop_loss(item.get("initial_stop_loss"), item.get("stop_loss")) item["display_open_stop_loss"] = open_stop @@ -2370,7 +2371,7 @@ def insert_trade_record( open_ts_ms = _to_ms_with_fallback(opened_at_ms, open_ts) close_ts_ms = _to_ms_with_fallback(closed_at_ms, close_ts) kst = key_signal_type_for_trade_record(key_signal_type, KEY_MONITOR_AUTO_TYPES) - from order_monitor_display_lib import snapshot_stop_loss + from lib.trade.order_monitor_display_lib import snapshot_stop_loss snap_sl = snapshot_stop_loss(initial_stop_loss, stop_loss) er = ( @@ -2761,7 +2762,7 @@ def get_exchange_capitals(force=False): def execute_transfer_usdt(amount, from_account, to_account): - from gate_transfer_lib import execute_transfer_usdt as _gate_execute_transfer_usdt + from lib.exchange.gate_transfer_lib import execute_transfer_usdt as _gate_execute_transfer_usdt return _gate_execute_transfer_usdt( exchange, @@ -2794,7 +2795,7 @@ def get_account_usdt_total(account_type): def _auto_transfer_active_count(conn): - from gate_transfer_lib import count_auto_transfer_blockers + from lib.exchange.gate_transfer_lib import count_auto_transfer_blockers return count_auto_transfer_blockers(conn, count_order_monitors=get_active_position_count) @@ -2878,7 +2879,7 @@ def resolve_capital_base_for_key_open(conn, trading_day, live_capital): def precheck_risk(conn, symbol, direction): now = app_now() - from account_risk_lib import account_risk_blocks_trading + from lib.trade.account_risk_lib import account_risk_blocks_trading ok_risk, risk_reason = account_risk_blocks_trading( conn, @@ -2890,7 +2891,7 @@ def precheck_risk(conn, symbol, direction): return False, risk_reason if not trading_day_reset_allows_new_open(now): return False, f"北京时间 {TRADING_DAY_RESET_HOUR}:00 前不允许持仓" - from account_risk_lib import position_limit_reached + from lib.trade.account_risk_lib import position_limit_reached reached, active_count, mx = position_limit_reached(conn, max_active_positions=MAX_ACTIVE_POSITIONS) if reached: @@ -3670,7 +3671,7 @@ def parse_ccxt_position_metrics(position, order_leverage=None): cs = float(get_contract_size(sym)) if sym else 1.0 except Exception: cs = 1.0 - from hub_position_metrics import enrich_ccxt_position_metrics_out + from lib.hub.hub_position_metrics import enrich_ccxt_position_metrics_out enrich_ccxt_position_metrics_out(p, out, contract_size=cs, funds_decimals=2) return out or None @@ -3854,7 +3855,7 @@ def fetch_latest_closing_fill(exchange_symbol, direction, opened_at_str, opened_ except Exception: pass try: - from gate_position_history_lib import pick_gate_position_close + from lib.exchange.gate_position_history_lib import pick_gate_position_close pos = pick_gate_position_close( fetch_gate_positions_close_history(), @@ -4114,7 +4115,7 @@ def reconcile_hub_external_close(conn, symbol, direction): """中控市价全平后:立即同步匹配 order_monitor,并读 Gate 平仓历史。""" if not exchange_private_api_configured(): return {"ok": False, "msg": "未配置 GATE_API_KEY / GATE_API_SECRET", "synced": 0} - from gate_position_history_lib import unified_symbol_for_match + from lib.exchange.gate_position_history_lib import unified_symbol_for_match sym_u = unified_symbol_for_match(symbol) dir_l = (direction or "").strip().lower() @@ -6513,14 +6514,14 @@ def background_task(): check_trigger_entry_key_monitors() _roll_cfg = app.extensions.get("strategy_roll_cfg") if _roll_cfg: - from strategy_roll_monitor_lib import check_roll_monitors + from lib.strategy.strategy_roll_monitor_lib import check_roll_monitors check_roll_monitors(_roll_cfg) check_key_monitors() check_order_monitors() cfg = app.extensions.get("strategy_trend_cfg") if cfg: - from strategy_trend_register import check_trend_pullback_plans + from lib.strategy.strategy_trend_register import check_trend_pullback_plans check_trend_pullback_plans(cfg) except: @@ -6848,7 +6849,7 @@ def render_main_page(page="trade", embed_mode=None): conn = get_db() session_row = ensure_session(conn, trading_day) local_current_capital = float(session_row["current_capital"]) - from instance_embed_context_lib import ( + from lib.instance.instance_embed_context_lib import ( embed_render_plan, minimal_stats_bundle, trade_records_summary, @@ -6921,7 +6922,7 @@ def render_main_page(page="trade", embed_mode=None): records = [] total = miss_count = rate = occupied_miss_total = 0 active_count = len(order_list) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors position_limit_count = count_position_limit_active_monitors(conn) opens_today = count_opens_for_trading_day(conn, trading_day) @@ -6952,7 +6953,7 @@ def render_main_page(page="trade", embed_mode=None): ) strategy_extra = {} if plan.strategy: - from strategy_ui import strategy_render_extras + from lib.strategy.strategy_ui import strategy_render_extras strategy_extra = strategy_render_extras( conn, @@ -6962,7 +6963,7 @@ def render_main_page(page="trade", embed_mode=None): trend_cfg=app.extensions.get("strategy_trend_cfg"), ) conn.close() - from instance_embed_lib import embed_context_extras + from lib.instance.instance_embed_lib import embed_context_extras template_ctx = dict( page=page, @@ -7104,7 +7105,7 @@ def api_account_snapshot(): funding_usdt = round(funding_capital, 2) if funding_capital is not None else None current_capital = round(trading_capital, 2) if trading_capital is not None else round(local_current_capital, 2) recommended_capital = round(float(get_recommended_capital(current_capital)), 2) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors position_limit_count = count_position_limit_active_monitors(conn) opens_today = count_opens_for_trading_day(conn, trading_day) @@ -7414,7 +7415,7 @@ def api_price_snapshot(): pass conn.close() - from hub_position_metrics import build_position_marks_list + from lib.hub.hub_position_metrics import build_position_marks_list position_marks = build_position_marks_list( all_swap_positions, @@ -7647,7 +7648,7 @@ def api_order_kline(): "volume": float(bar[5]), }) - from focus_chart_lib import ( + from lib.instance.focus_chart_lib import ( build_order_kline_order_payload, load_swap_positions_for_order_kline, metrics_for_order_item, @@ -7675,7 +7676,7 @@ def api_order_kline(): ex_metrics=ex_metrics, ) - from focus_chart_lib import kline_api_price_fields + from lib.instance.focus_chart_lib import kline_api_price_fields price_fields = kline_api_price_fields( exchange, @@ -7792,7 +7793,7 @@ def api_key_kline(): "lower_pct": lower_pct, } - from focus_chart_lib import enrich_key_kline_response + from lib.instance.focus_chart_lib import enrich_key_kline_response price_display, key_info = enrich_key_kline_response( symbol=symbol, @@ -7801,7 +7802,7 @@ def api_key_kline(): format_price_fn=format_price_for_symbol, ) - from focus_chart_lib import kline_api_price_fields + from lib.instance.focus_chart_lib import kline_api_price_fields price_fields = kline_api_price_fields( exchange, @@ -8756,7 +8757,7 @@ def del_order(id): opened_at=opened_at, closed_at=closed_at, ) - from account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close on_user_initiated_close( conn, @@ -8770,7 +8771,7 @@ def del_order(id): try: _rcfg = app.extensions.get("strategy_roll_cfg") if isinstance(_rcfg, dict): - from strategy_register import roll_sync_after_external_close + from lib.strategy.strategy_register import roll_sync_after_external_close roll_sync_after_external_close(_rcfg, conn, row["symbol"], row["direction"]) except Exception: @@ -8832,7 +8833,7 @@ def del_order(id): opened_at=opened_at, closed_at=closed_at, ) - from account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close on_user_initiated_close( conn, @@ -8846,7 +8847,7 @@ def del_order(id): try: _rcfg = app.extensions.get("strategy_roll_cfg") if isinstance(_rcfg, dict): - from strategy_register import roll_sync_after_external_close + from lib.strategy.strategy_register import roll_sync_after_external_close roll_sync_after_external_close(_rcfg, conn, row["symbol"], row["direction"]) except Exception: @@ -9020,7 +9021,7 @@ def add_journal(): d.get("post_breakeven_stare"), d.get("new_trade_while_occupied"), d.get("note"), image_filename ) ) - from account_risk_lib import on_journal_saved + from lib.trade.account_risk_lib import on_journal_saved on_journal_saved( conn, @@ -9108,7 +9109,7 @@ def api_reviews(): return jsonify([row_to_dict(r) for r in rows]) -_REPO_STATIC_DIR = os.path.join(os.path.dirname(BASE_DIR), "static") +_REPO_STATIC_DIR = common_static_dir(os.path.dirname(BASE_DIR)) _AI_REVIEW_RENDER_JS = os.path.join(_REPO_STATIC_DIR, "ai_review_render.js") _FORM_SUBMIT_GUARD_JS = os.path.join(_REPO_STATIC_DIR, "form_submit_guard.js") _MANUAL_ORDER_RR_PREVIEW_JS = os.path.join(_REPO_STATIC_DIR, "manual_order_rr_preview.js") @@ -9342,7 +9343,7 @@ def api_trade_record_review_update(): tuple(base_params + [rec_id]), ) if reviewed_result == "手动平仓" and reviewed_miss_reason: - from account_risk_lib import apply_manual_close_journal_cooloff + from lib.trade.account_risk_lib import apply_manual_close_journal_cooloff apply_manual_close_journal_cooloff( conn, @@ -9491,7 +9492,7 @@ def _hub_account_bundle(): def _hub_fetch_market(base=""): - from hub_market_info_lib import fetch_usdt_swap_market_info + from lib.hub.hub_market_info_lib import fetch_usdt_swap_market_info return fetch_usdt_swap_market_info( base_or_symbol=base, @@ -9504,7 +9505,7 @@ def _hub_fetch_market(base=""): def _hub_fetch_ohlcv(symbol, timeframe, since_ms=None, limit=500): - from hub_ohlcv_lib import fetch_ohlcv_for_hub + from lib.hub.hub_ohlcv_lib import fetch_ohlcv_for_hub return fetch_ohlcv_for_hub( symbol=symbol, @@ -9520,7 +9521,7 @@ def _hub_fetch_ohlcv(symbol, timeframe, since_ms=None, limit=500): def _hub_fetch_volume_rank(top_n=20): - from hub_volume_rank_lib import fetch_usdt_swap_volume_rank + from lib.hub.hub_volume_rank_lib import fetch_usdt_swap_volume_rank return fetch_usdt_swap_volume_rank( exchange=exchange, @@ -9537,7 +9538,7 @@ try: _repo_root = Path(__file__).resolve().parent.parent if str(_repo_root) not in sys.path: sys.path.insert(0, str(_repo_root)) - from hub_bridge import install_on_app + from lib.hub.hub_bridge import install_on_app install_on_app( app, @@ -9581,8 +9582,8 @@ def strategy_roll_page(): return redirect("/strategy") -from strategy_register import install_strategy_trading -from strategy_trend_register import install_strategy_trend +from lib.strategy.strategy_register import install_strategy_trading +from lib.strategy.strategy_trend_register import install_strategy_trend install_strategy_trading(app, _REPO_ROOT, app_module=sys.modules[__name__]) install_strategy_trend(app, _REPO_ROOT, app_module=sys.modules[__name__]) diff --git a/crypto_monitor_gate_bot/app.py b/crypto_monitor_gate_bot/app.py index 0a698c0..8b9f158 100644 --- a/crypto_monitor_gate_bot/app.py +++ b/crypto_monitor_gate_bot/app.py @@ -34,14 +34,15 @@ import sys if _REPO_ROOT not in sys.path: sys.path.insert(0, _REPO_ROOT) -from ai_client import ai_generate, ai_review, ai_short_advice -from ai_review_lib import ( +from lib.paths import common_static_dir +from lib.ai.ai_client import ai_generate, ai_review, ai_short_advice +from lib.ai.ai_review_lib import ( build_journal_ai_chart_path, collect_images_for_ai_review, journal_row_lines_for_ai, ) -from form_submit_lib import check_duplicate_submit, submit_scope_add_key, submit_scope_add_order -from fib_key_monitor_lib import ( +from lib.common.form_submit_lib import check_duplicate_submit, submit_scope_add_key, submit_scope_add_order +from lib.key_monitor.fib_key_monitor_lib import ( FIB_KEY_MONITOR_TYPES, KEY_ENTRY_REASON_BY_SIGNAL, backfill_missing_key_signal_types, @@ -53,7 +54,7 @@ from fib_key_monitor_lib import ( key_signal_type_for_trade_record, stored_key_signal_type, ) -from false_breakout_key_monitor_lib import ( +from lib.key_monitor.false_breakout_key_monitor_lib import ( FALSE_BREAKOUT_MONITOR_TYPE, FALSE_BREAKOUT_VALIDITY_HOURS, calc_false_breakout_plan, @@ -66,7 +67,7 @@ from false_breakout_key_monitor_lib import ( normalize_false_breakout_symbol, storage_bounds_from_key_price, ) -from strategy_trade_labels import ( +from lib.strategy.strategy_trade_labels import ( STRATEGY_ENTRY_REASON_OPTIONS, apply_order_monitor_source_labels, entry_reason_for_monitor_type, @@ -75,7 +76,7 @@ from strategy_trade_labels import ( trade_record_monitor_type as resolve_trade_record_monitor_type, trend_plan_id_from_monitor_row, ) -from journal_chart_lib import ( +from lib.instance.journal_chart_lib import ( JOURNAL_CHART_DEFAULT_LIMIT, JOURNAL_CHART_DEFAULT_TF1, JOURNAL_CHART_DEFAULT_TF2, @@ -91,7 +92,7 @@ from journal_chart_lib import ( trade_review_fetch_window, trim_rows_for_trade_review, ) -from key_sl_tp_lib import ( +from lib.key_monitor.key_sl_tp_lib import ( breakeven_enabled_from_row, normalize_sl_tp_mode, parse_breakeven_enabled_form, @@ -100,7 +101,7 @@ from key_sl_tp_lib import ( sl_tp_mode_label, sl_tp_plan_summary_text, ) -from time_close_lib import ( +from lib.trade.time_close_lib import ( TIME_CLOSE_RESULT, apply_time_close_to_payload, ensure_time_close_schema, @@ -111,13 +112,13 @@ from time_close_lib import ( time_close_label, time_close_settings_from_row, ) -from manual_sltp_lib import ( +from lib.trade.manual_sltp_lib import ( normalize_open_sltp_mode, resolve_entrust_sltp_prices, resolve_open_sltp_prices, ) -from key_monitor_schema_lib import ensure_key_monitor_schema -from trigger_entry_key_monitor_lib import ( +from lib.key_monitor.key_monitor_schema_lib import ensure_key_monitor_schema +from lib.key_monitor.trigger_entry_key_monitor_lib import ( BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, TRIGGER_ENTRY_CLOSE_EXCHANGE_FAILED, @@ -140,7 +141,7 @@ from trigger_entry_key_monitor_lib import ( validate_trigger_entry_geometry, validate_trigger_entry_rr, ) -from position_sizing_lib import ( +from lib.trade.position_sizing_lib import ( OPEN_SOURCE_KEY_AUTO, OPEN_SOURCE_KEY_TRIGGER, OPEN_SOURCE_MANUAL, @@ -154,12 +155,12 @@ from position_sizing_lib import ( mode_label_zh, risk_percent_for_storage, ) -from key_monitor_full_margin_lib import ( +from lib.key_monitor.key_monitor_full_margin_lib import ( monitor_type_disallowed_in_full_margin, purge_disallowed_key_monitors, ) -from auto_transfer_daily_lib import run_auto_transfer_once_per_day -from key_monitor_lib import ( +from lib.common.auto_transfer_daily_lib import run_auto_transfer_once_per_day +from lib.key_monitor.key_monitor_lib import ( KEY_DIRECTION_WATCH, KEY_MONITOR_ALERT_ONLY_TYPES, KEY_MONITOR_AUTO_TYPES, @@ -179,16 +180,16 @@ from key_monitor_lib import ( rs_break_from_direction, run_rs_level_alert_tick, ) -from order_monitor_display_lib import ( +from lib.trade.order_monitor_display_lib import ( apply_order_price_display_fields, enrich_order_display_fields, order_monitor_tpsl_needs_sync, ) -from wechat_notify_lib import build_wechat_rs_level_message, send_wechat_webhook -from hub_auth import request_allowed as hub_request_allowed -from instance_nav_lib import request_is_hub_soft_nav -from hub_volume_rank_lib import resolve_daily_volume_rank -from history_window_lib import ( +from lib.common.wechat_notify_lib import build_wechat_rs_level_message, send_wechat_webhook +from lib.hub.hub_auth import request_allowed as hub_request_allowed +from lib.instance.instance_nav_lib import request_is_hub_soft_nav +from lib.hub.hub_volume_rank_lib import resolve_daily_volume_rank +from lib.common.history_window_lib import ( PRESET_CUSTOM, PRESET_UTC_LAST24H, PRESET_UTC_LAST7D, @@ -201,8 +202,8 @@ from history_window_lib import ( utc_window_to_bj_sql_strings, utc_window_to_utc_sql_strings, ) -from trade_result_lib import count_winning_trades, normalize_result_with_pnl -from trade_exchange_stats_lib import attach_exchange_stats_to_trade, filter_position_lifecycle_fills +from lib.trade.trade_result_lib import count_winning_trades, normalize_result_with_pnl +from lib.trade.trade_exchange_stats_lib import attach_exchange_stats_to_trade, filter_position_lifecycle_fills def load_env_file(path): @@ -343,7 +344,7 @@ ORDER_CHART_ENABLED = os.getenv("ORDER_CHART_ENABLED", "true").lower() == "true" ORDER_CHART_TFS = [x.strip() for x in (os.getenv("ORDER_CHART_TFS", "4h,1h,15m,5m") or "").split(",") if x.strip()] ORDER_CHART_LIMIT = int(os.getenv("ORDER_CHART_LIMIT", "100")) ORDER_CHART_DIR = resolve_path(os.getenv("ORDER_CHART_DIR", "static/images/order_charts")) -from daily_open_limit_lib import ( +from lib.trade.daily_open_limit_lib import ( build_daily_open_alert_prompt, can_trade_new_open, check_daily_open_hard_limit, @@ -1506,10 +1507,10 @@ def init_db(): close_reason TEXT, closed_at TEXT)""" ) - from strategy_db import init_strategy_tables + from lib.strategy.strategy_db import init_strategy_tables init_strategy_tables(conn) - from account_risk_lib import ensure_account_risk_schema + from lib.trade.account_risk_lib import ensure_account_risk_schema ensure_account_risk_schema(conn) backfill_missing_key_signal_types(conn, monitor_type=ORDER_MONITOR_TYPE_KEY_AUTO) @@ -1546,7 +1547,7 @@ def get_db(): def hub_account_risk_status(conn): - from account_risk_lib import ( + from lib.trade.account_risk_lib import ( apply_position_limit_risk, compute_account_risk_status, enrich_risk_status_countdown, @@ -1562,7 +1563,7 @@ def hub_account_risk_status(conn): fmt_local_ms=ms_to_app_local_str, ) st = enrich_risk_status_countdown(st, now=now, daily_reset_hour=TRADING_DAY_RESET_HOUR) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors return apply_position_limit_risk( st, @@ -1579,7 +1580,7 @@ def hub_user_initiated_close( trade_record_id=None, closed_at_ms=None, ): - from account_risk_lib import CLOSE_SOURCE_USER_HUB, on_user_initiated_close + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_HUB, on_user_initiated_close src = (source or "").strip() or CLOSE_SOURCE_USER_HUB on_user_initiated_close( @@ -2072,7 +2073,7 @@ def get_effective_trade_field(row, reviewed_key, base_key, default=None): def to_effective_trade_dict(row): item = row_to_dict(row) - from order_monitor_display_lib import snapshot_stop_loss + from lib.trade.order_monitor_display_lib import snapshot_stop_loss open_stop = snapshot_stop_loss(item.get("initial_stop_loss"), item.get("stop_loss")) item["display_open_stop_loss"] = open_stop @@ -2370,7 +2371,7 @@ def insert_trade_record( open_ts_ms = _to_ms_with_fallback(opened_at_ms, open_ts) close_ts_ms = _to_ms_with_fallback(closed_at_ms, close_ts) kst = key_signal_type_for_trade_record(key_signal_type, KEY_MONITOR_AUTO_TYPES) - from order_monitor_display_lib import snapshot_stop_loss + from lib.trade.order_monitor_display_lib import snapshot_stop_loss snap_sl = snapshot_stop_loss(initial_stop_loss, stop_loss) er = ( @@ -2761,7 +2762,7 @@ def get_exchange_capitals(force=False): def execute_transfer_usdt(amount, from_account, to_account): - from gate_transfer_lib import execute_transfer_usdt as _gate_execute_transfer_usdt + from lib.exchange.gate_transfer_lib import execute_transfer_usdt as _gate_execute_transfer_usdt return _gate_execute_transfer_usdt( exchange, @@ -2794,7 +2795,7 @@ def get_account_usdt_total(account_type): def _auto_transfer_active_count(conn): - from gate_transfer_lib import count_auto_transfer_blockers + from lib.exchange.gate_transfer_lib import count_auto_transfer_blockers return count_auto_transfer_blockers(conn, count_order_monitors=get_active_position_count) @@ -2878,7 +2879,7 @@ def resolve_capital_base_for_key_open(conn, trading_day, live_capital): def precheck_risk(conn, symbol, direction): now = app_now() - from account_risk_lib import account_risk_blocks_trading + from lib.trade.account_risk_lib import account_risk_blocks_trading ok_risk, risk_reason = account_risk_blocks_trading( conn, @@ -2890,7 +2891,7 @@ def precheck_risk(conn, symbol, direction): return False, risk_reason if not trading_day_reset_allows_new_open(now): return False, f"北京时间 {TRADING_DAY_RESET_HOUR}:00 前不允许持仓" - from account_risk_lib import position_limit_reached + from lib.trade.account_risk_lib import position_limit_reached reached, active_count, mx = position_limit_reached(conn, max_active_positions=MAX_ACTIVE_POSITIONS) if reached: @@ -3670,7 +3671,7 @@ def parse_ccxt_position_metrics(position, order_leverage=None): cs = float(get_contract_size(sym)) if sym else 1.0 except Exception: cs = 1.0 - from hub_position_metrics import enrich_ccxt_position_metrics_out + from lib.hub.hub_position_metrics import enrich_ccxt_position_metrics_out enrich_ccxt_position_metrics_out(p, out, contract_size=cs, funds_decimals=2) return out or None @@ -3854,7 +3855,7 @@ def fetch_latest_closing_fill(exchange_symbol, direction, opened_at_str, opened_ except Exception: pass try: - from gate_position_history_lib import pick_gate_position_close + from lib.exchange.gate_position_history_lib import pick_gate_position_close pos = pick_gate_position_close( fetch_gate_positions_close_history(), @@ -4114,7 +4115,7 @@ def reconcile_hub_external_close(conn, symbol, direction): """中控市价全平后:立即同步匹配 order_monitor,并读 Gate 平仓历史。""" if not exchange_private_api_configured(): return {"ok": False, "msg": "未配置 GATE_API_KEY / GATE_API_SECRET", "synced": 0} - from gate_position_history_lib import unified_symbol_for_match + from lib.exchange.gate_position_history_lib import unified_symbol_for_match sym_u = unified_symbol_for_match(symbol) dir_l = (direction or "").strip().lower() @@ -6513,14 +6514,14 @@ def background_task(): check_trigger_entry_key_monitors() _roll_cfg = app.extensions.get("strategy_roll_cfg") if _roll_cfg: - from strategy_roll_monitor_lib import check_roll_monitors + from lib.strategy.strategy_roll_monitor_lib import check_roll_monitors check_roll_monitors(_roll_cfg) check_key_monitors() check_order_monitors() cfg = app.extensions.get("strategy_trend_cfg") if cfg: - from strategy_trend_register import check_trend_pullback_plans + from lib.strategy.strategy_trend_register import check_trend_pullback_plans check_trend_pullback_plans(cfg) except: @@ -6848,7 +6849,7 @@ def render_main_page(page="trade", embed_mode=None): conn = get_db() session_row = ensure_session(conn, trading_day) local_current_capital = float(session_row["current_capital"]) - from instance_embed_context_lib import ( + from lib.instance.instance_embed_context_lib import ( embed_render_plan, minimal_stats_bundle, trade_records_summary, @@ -6921,7 +6922,7 @@ def render_main_page(page="trade", embed_mode=None): records = [] total = miss_count = rate = occupied_miss_total = 0 active_count = len(order_list) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors position_limit_count = count_position_limit_active_monitors(conn) opens_today = count_opens_for_trading_day(conn, trading_day) @@ -6952,7 +6953,7 @@ def render_main_page(page="trade", embed_mode=None): ) strategy_extra = {} if plan.strategy: - from strategy_ui import strategy_render_extras + from lib.strategy.strategy_ui import strategy_render_extras strategy_extra = strategy_render_extras( conn, @@ -6962,7 +6963,7 @@ def render_main_page(page="trade", embed_mode=None): trend_cfg=app.extensions.get("strategy_trend_cfg"), ) conn.close() - from instance_embed_lib import embed_context_extras + from lib.instance.instance_embed_lib import embed_context_extras template_ctx = dict( page=page, @@ -7100,7 +7101,7 @@ def api_account_snapshot(): funding_usdt = round(funding_capital, 2) if funding_capital is not None else None current_capital = round(trading_capital, 2) if trading_capital is not None else round(local_current_capital, 2) recommended_capital = round(float(get_recommended_capital(current_capital)), 2) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors position_limit_count = count_position_limit_active_monitors(conn) opens_today = count_opens_for_trading_day(conn, trading_day) @@ -7410,7 +7411,7 @@ def api_price_snapshot(): pass conn.close() - from hub_position_metrics import build_position_marks_list + from lib.hub.hub_position_metrics import build_position_marks_list position_marks = build_position_marks_list( all_swap_positions, @@ -7643,7 +7644,7 @@ def api_order_kline(): "volume": float(bar[5]), }) - from focus_chart_lib import ( + from lib.instance.focus_chart_lib import ( build_order_kline_order_payload, load_swap_positions_for_order_kline, metrics_for_order_item, @@ -7671,7 +7672,7 @@ def api_order_kline(): ex_metrics=ex_metrics, ) - from focus_chart_lib import kline_api_price_fields + from lib.instance.focus_chart_lib import kline_api_price_fields price_fields = kline_api_price_fields( exchange, @@ -7788,7 +7789,7 @@ def api_key_kline(): "lower_pct": lower_pct, } - from focus_chart_lib import enrich_key_kline_response + from lib.instance.focus_chart_lib import enrich_key_kline_response price_display, key_info = enrich_key_kline_response( symbol=symbol, @@ -7797,7 +7798,7 @@ def api_key_kline(): format_price_fn=format_price_for_symbol, ) - from focus_chart_lib import kline_api_price_fields + from lib.instance.focus_chart_lib import kline_api_price_fields price_fields = kline_api_price_fields( exchange, @@ -8752,7 +8753,7 @@ def del_order(id): opened_at=opened_at, closed_at=closed_at, ) - from account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close on_user_initiated_close( conn, @@ -8766,7 +8767,7 @@ def del_order(id): try: _rcfg = app.extensions.get("strategy_roll_cfg") if isinstance(_rcfg, dict): - from strategy_register import roll_sync_after_external_close + from lib.strategy.strategy_register import roll_sync_after_external_close roll_sync_after_external_close(_rcfg, conn, row["symbol"], row["direction"]) except Exception: @@ -8828,7 +8829,7 @@ def del_order(id): opened_at=opened_at, closed_at=closed_at, ) - from account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close on_user_initiated_close( conn, @@ -8842,7 +8843,7 @@ def del_order(id): try: _rcfg = app.extensions.get("strategy_roll_cfg") if isinstance(_rcfg, dict): - from strategy_register import roll_sync_after_external_close + from lib.strategy.strategy_register import roll_sync_after_external_close roll_sync_after_external_close(_rcfg, conn, row["symbol"], row["direction"]) except Exception: @@ -9016,7 +9017,7 @@ def add_journal(): d.get("post_breakeven_stare"), d.get("new_trade_while_occupied"), d.get("note"), image_filename ) ) - from account_risk_lib import on_journal_saved + from lib.trade.account_risk_lib import on_journal_saved on_journal_saved( conn, @@ -9104,7 +9105,7 @@ def api_reviews(): return jsonify([row_to_dict(r) for r in rows]) -_REPO_STATIC_DIR = os.path.join(os.path.dirname(BASE_DIR), "static") +_REPO_STATIC_DIR = common_static_dir(os.path.dirname(BASE_DIR)) _AI_REVIEW_RENDER_JS = os.path.join(_REPO_STATIC_DIR, "ai_review_render.js") _FORM_SUBMIT_GUARD_JS = os.path.join(_REPO_STATIC_DIR, "form_submit_guard.js") _MANUAL_ORDER_RR_PREVIEW_JS = os.path.join(_REPO_STATIC_DIR, "manual_order_rr_preview.js") @@ -9338,7 +9339,7 @@ def api_trade_record_review_update(): tuple(base_params + [rec_id]), ) if reviewed_result == "手动平仓" and reviewed_miss_reason: - from account_risk_lib import apply_manual_close_journal_cooloff + from lib.trade.account_risk_lib import apply_manual_close_journal_cooloff apply_manual_close_journal_cooloff( conn, @@ -9487,7 +9488,7 @@ def _hub_account_bundle(): def _hub_fetch_market(base=""): - from hub_market_info_lib import fetch_usdt_swap_market_info + from lib.hub.hub_market_info_lib import fetch_usdt_swap_market_info return fetch_usdt_swap_market_info( base_or_symbol=base, @@ -9500,7 +9501,7 @@ def _hub_fetch_market(base=""): def _hub_fetch_ohlcv(symbol, timeframe, since_ms=None, limit=500): - from hub_ohlcv_lib import fetch_ohlcv_for_hub + from lib.hub.hub_ohlcv_lib import fetch_ohlcv_for_hub return fetch_ohlcv_for_hub( symbol=symbol, @@ -9516,7 +9517,7 @@ def _hub_fetch_ohlcv(symbol, timeframe, since_ms=None, limit=500): def _hub_fetch_volume_rank(top_n=20): - from hub_volume_rank_lib import fetch_usdt_swap_volume_rank + from lib.hub.hub_volume_rank_lib import fetch_usdt_swap_volume_rank return fetch_usdt_swap_volume_rank( exchange=exchange, @@ -9533,7 +9534,7 @@ try: _repo_root = Path(__file__).resolve().parent.parent if str(_repo_root) not in sys.path: sys.path.insert(0, str(_repo_root)) - from hub_bridge import install_on_app + from lib.hub.hub_bridge import install_on_app install_on_app( app, @@ -9577,8 +9578,8 @@ def strategy_roll_page(): return redirect("/strategy") -from strategy_register import install_strategy_trading -from strategy_trend_register import install_strategy_trend +from lib.strategy.strategy_register import install_strategy_trading +from lib.strategy.strategy_trend_register import install_strategy_trend install_strategy_trading(app, _REPO_ROOT, app_module=sys.modules[__name__]) install_strategy_trend(app, _REPO_ROOT, app_module=sys.modules[__name__]) diff --git a/crypto_monitor_okx/app.py b/crypto_monitor_okx/app.py index 82e50ff..7aed8da 100644 --- a/crypto_monitor_okx/app.py +++ b/crypto_monitor_okx/app.py @@ -34,14 +34,15 @@ import sys if _REPO_ROOT not in sys.path: sys.path.insert(0, _REPO_ROOT) -from ai_client import ai_generate, ai_review, ai_short_advice -from ai_review_lib import ( +from lib.paths import common_static_dir +from lib.ai.ai_client import ai_generate, ai_review, ai_short_advice +from lib.ai.ai_review_lib import ( build_journal_ai_chart_path, collect_images_for_ai_review, journal_row_lines_for_ai, ) -from form_submit_lib import check_duplicate_submit, submit_scope_add_key, submit_scope_add_order -from fib_key_monitor_lib import ( +from lib.common.form_submit_lib import check_duplicate_submit, submit_scope_add_key, submit_scope_add_order +from lib.key_monitor.fib_key_monitor_lib import ( FIB_KEY_MONITOR_TYPES, backfill_missing_key_signal_types, calc_fib_plan, @@ -52,7 +53,7 @@ from fib_key_monitor_lib import ( key_signal_type_for_trade_record, stored_key_signal_type, ) -from false_breakout_key_monitor_lib import ( +from lib.key_monitor.false_breakout_key_monitor_lib import ( FALSE_BREAKOUT_MONITOR_TYPE, FALSE_BREAKOUT_VALIDITY_HOURS, calc_false_breakout_plan, @@ -65,7 +66,7 @@ from false_breakout_key_monitor_lib import ( normalize_false_breakout_symbol, storage_bounds_from_key_price, ) -from strategy_trade_labels import ( +from lib.strategy.strategy_trade_labels import ( STRATEGY_ENTRY_REASON_OPTIONS, apply_order_monitor_source_labels, entry_reason_for_monitor_type, @@ -74,8 +75,8 @@ from strategy_trade_labels import ( trade_record_monitor_type as resolve_trade_record_monitor_type, trend_plan_id_from_monitor_row, ) -from okx_orders_lib import cancel_okx_all_open_orders, fetch_okx_all_open_orders -from journal_chart_lib import ( +from lib.exchange.okx_orders_lib import cancel_okx_all_open_orders, fetch_okx_all_open_orders +from lib.instance.journal_chart_lib import ( JOURNAL_CHART_DEFAULT_LIMIT, JOURNAL_CHART_DEFAULT_TF1, JOURNAL_CHART_DEFAULT_TF2, @@ -91,7 +92,7 @@ from journal_chart_lib import ( trade_review_fetch_window, trim_rows_for_trade_review, ) -from key_sl_tp_lib import ( +from lib.key_monitor.key_sl_tp_lib import ( breakeven_enabled_from_row, normalize_sl_tp_mode, parse_breakeven_enabled_form, @@ -100,7 +101,7 @@ from key_sl_tp_lib import ( sl_tp_mode_label, sl_tp_plan_summary_text, ) -from time_close_lib import ( +from lib.trade.time_close_lib import ( TIME_CLOSE_RESULT, apply_time_close_to_payload, ensure_time_close_schema, @@ -111,13 +112,13 @@ from time_close_lib import ( time_close_label, time_close_settings_from_row, ) -from manual_sltp_lib import ( +from lib.trade.manual_sltp_lib import ( normalize_open_sltp_mode, resolve_entrust_sltp_prices, resolve_open_sltp_prices, ) -from key_monitor_schema_lib import ensure_key_monitor_schema -from trigger_entry_key_monitor_lib import ( +from lib.key_monitor.key_monitor_schema_lib import ensure_key_monitor_schema +from lib.key_monitor.trigger_entry_key_monitor_lib import ( BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, TRIGGER_ENTRY_CLOSE_EXCHANGE_FAILED, @@ -140,7 +141,7 @@ from trigger_entry_key_monitor_lib import ( validate_trigger_entry_geometry, validate_trigger_entry_rr, ) -from position_sizing_lib import ( +from lib.trade.position_sizing_lib import ( OPEN_SOURCE_KEY_AUTO, OPEN_SOURCE_MANUAL, assert_open_source_allowed, @@ -153,12 +154,12 @@ from position_sizing_lib import ( mode_label_zh, risk_percent_for_storage, ) -from key_monitor_full_margin_lib import ( +from lib.key_monitor.key_monitor_full_margin_lib import ( monitor_type_disallowed_in_full_margin, purge_disallowed_key_monitors, ) -from auto_transfer_daily_lib import run_auto_transfer_once_per_day -from key_monitor_lib import ( +from lib.common.auto_transfer_daily_lib import run_auto_transfer_once_per_day +from lib.key_monitor.key_monitor_lib import ( KEY_DIRECTION_WATCH, KEY_MONITOR_ALERT_ONLY_TYPES, KEY_MONITOR_AUTO_TYPES, @@ -178,16 +179,16 @@ from key_monitor_lib import ( rs_break_from_direction, run_rs_level_alert_tick, ) -from order_monitor_display_lib import ( +from lib.trade.order_monitor_display_lib import ( apply_order_price_display_fields, enrich_order_display_fields, order_monitor_tpsl_needs_sync, ) -from wechat_notify_lib import build_wechat_rs_level_message, send_wechat_webhook -from hub_auth import request_allowed as hub_request_allowed -from instance_nav_lib import request_is_hub_soft_nav -from hub_volume_rank_lib import resolve_daily_volume_rank -from history_window_lib import ( +from lib.common.wechat_notify_lib import build_wechat_rs_level_message, send_wechat_webhook +from lib.hub.hub_auth import request_allowed as hub_request_allowed +from lib.instance.instance_nav_lib import request_is_hub_soft_nav +from lib.hub.hub_volume_rank_lib import resolve_daily_volume_rank +from lib.common.history_window_lib import ( PRESET_CUSTOM, PRESET_UTC_LAST24H, PRESET_UTC_LAST7D, @@ -200,8 +201,8 @@ from history_window_lib import ( utc_window_to_bj_sql_strings, utc_window_to_utc_sql_strings, ) -from trade_result_lib import count_winning_trades, normalize_result_with_pnl -from trade_exchange_stats_lib import attach_exchange_stats_to_trade, filter_position_lifecycle_fills +from lib.trade.trade_result_lib import count_winning_trades, normalize_result_with_pnl +from lib.trade.trade_exchange_stats_lib import attach_exchange_stats_to_trade, filter_position_lifecycle_fills def load_env_file(path): @@ -323,7 +324,7 @@ ORDER_CHART_ENABLED = os.getenv("ORDER_CHART_ENABLED", "true").lower() == "true" ORDER_CHART_TFS = [x.strip() for x in (os.getenv("ORDER_CHART_TFS", "4h,1h,15m,5m") or "").split(",") if x.strip()] ORDER_CHART_LIMIT = int(os.getenv("ORDER_CHART_LIMIT", "100")) ORDER_CHART_DIR = resolve_path(os.getenv("ORDER_CHART_DIR", "static/images/order_charts")) -from daily_open_limit_lib import ( +from lib.trade.daily_open_limit_lib import ( build_daily_open_alert_prompt, can_trade_new_open, check_daily_open_hard_limit, @@ -1493,10 +1494,10 @@ def init_db(): close_reason TEXT, closed_at TEXT)""" ) - from strategy_db import init_strategy_tables + from lib.strategy.strategy_db import init_strategy_tables init_strategy_tables(conn) - from account_risk_lib import ensure_account_risk_schema + from lib.trade.account_risk_lib import ensure_account_risk_schema ensure_account_risk_schema(conn) backfill_missing_key_signal_types(conn, monitor_type=ORDER_MONITOR_TYPE_KEY_AUTO) @@ -1533,7 +1534,7 @@ def get_db(): def hub_account_risk_status(conn): - from account_risk_lib import ( + from lib.trade.account_risk_lib import ( apply_position_limit_risk, compute_account_risk_status, enrich_risk_status_countdown, @@ -1549,7 +1550,7 @@ def hub_account_risk_status(conn): fmt_local_ms=ms_to_app_local_str, ) st = enrich_risk_status_countdown(st, now=now, daily_reset_hour=TRADING_DAY_RESET_HOUR) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors return apply_position_limit_risk( st, @@ -1566,7 +1567,7 @@ def hub_user_initiated_close( trade_record_id=None, closed_at_ms=None, ): - from account_risk_lib import CLOSE_SOURCE_USER_HUB, on_user_initiated_close + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_HUB, on_user_initiated_close src = (source or "").strip() or CLOSE_SOURCE_USER_HUB on_user_initiated_close( @@ -2021,7 +2022,7 @@ def get_effective_trade_field(row, reviewed_key, base_key, default=None): def to_effective_trade_dict(row): item = row_to_dict(row) - from order_monitor_display_lib import snapshot_stop_loss + from lib.trade.order_monitor_display_lib import snapshot_stop_loss open_stop = snapshot_stop_loss(item.get("initial_stop_loss"), item.get("stop_loss")) item["display_open_stop_loss"] = open_stop @@ -2266,7 +2267,7 @@ def insert_trade_record( open_ts_ms = _to_ms_with_fallback(opened_at_ms, open_ts) close_ts_ms = _to_ms_with_fallback(closed_at_ms, close_ts) kst = key_signal_type_for_trade_record(key_signal_type, KEY_MONITOR_AUTO_TYPES) - from order_monitor_display_lib import snapshot_stop_loss + from lib.trade.order_monitor_display_lib import snapshot_stop_loss snap_sl = snapshot_stop_loss(initial_stop_loss, stop_loss) er = ( @@ -2590,7 +2591,7 @@ def trading_day_reset_allows_new_open(now, conn=None): def precheck_risk(conn, symbol, direction): now = app_now() - from account_risk_lib import account_risk_blocks_trading + from lib.trade.account_risk_lib import account_risk_blocks_trading ok_risk, risk_reason = account_risk_blocks_trading( conn, @@ -2602,7 +2603,7 @@ def precheck_risk(conn, symbol, direction): return False, risk_reason if not trading_day_reset_allows_new_open(now): return False, f"北京时间 {TRADING_DAY_RESET_HOUR}:00 前不允许持仓" - from account_risk_lib import position_limit_reached + from lib.trade.account_risk_lib import position_limit_reached reached, active_count, mx = position_limit_reached(conn, max_active_positions=MAX_ACTIVE_POSITIONS) if reached: @@ -2973,7 +2974,7 @@ def parse_ccxt_position_metrics(position, order_leverage=None): cs = float(get_contract_size(sym)) if sym else 1.0 except Exception: cs = 1.0 - from hub_position_metrics import enrich_ccxt_position_metrics_out + from lib.hub.hub_position_metrics import enrich_ccxt_position_metrics_out enrich_ccxt_position_metrics_out( p, out, contract_size=cs, funds_decimals=FUNDS_DECIMALS @@ -6252,14 +6253,14 @@ def background_task(): check_trigger_entry_key_monitors() _roll_cfg = app.extensions.get("strategy_roll_cfg") if _roll_cfg: - from strategy_roll_monitor_lib import check_roll_monitors + from lib.strategy.strategy_roll_monitor_lib import check_roll_monitors check_roll_monitors(_roll_cfg) check_key_monitors() check_order_monitors() cfg = app.extensions.get("strategy_trend_cfg") if cfg: - from strategy_trend_register import check_trend_pullback_plans + from lib.strategy.strategy_trend_register import check_trend_pullback_plans check_trend_pullback_plans(cfg) except: @@ -6348,7 +6349,7 @@ def render_main_page(page="trade", embed_mode=None): conn = get_db() session_row = ensure_session(conn, trading_day) local_current_capital = float(session_row["current_capital"]) - from instance_embed_context_lib import ( + from lib.instance.instance_embed_context_lib import ( embed_render_plan, minimal_stats_bundle, trade_records_summary, @@ -6420,7 +6421,7 @@ def render_main_page(page="trade", embed_mode=None): records = [] total = miss_count = rate = occupied_miss_total = 0 active_count = len(order_list) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors position_limit_count = count_position_limit_active_monitors(conn) open_guard_enabled = get_trading_day_reset_open_guard_enabled(conn) @@ -6453,7 +6454,7 @@ def render_main_page(page="trade", embed_mode=None): ) strategy_extra = {} if plan.strategy: - from strategy_ui import strategy_render_extras + from lib.strategy.strategy_ui import strategy_render_extras strategy_extra = strategy_render_extras( conn, @@ -6463,7 +6464,7 @@ def render_main_page(page="trade", embed_mode=None): trend_cfg=app.extensions.get("strategy_trend_cfg"), ) conn.close() - from instance_embed_lib import embed_context_extras + from lib.instance.instance_embed_lib import embed_context_extras template_ctx = dict( page=page, @@ -6600,7 +6601,7 @@ def api_account_snapshot(): funding_usdt = round(funding_capital, FUNDS_DECIMALS) if funding_capital is not None else None current_capital = round(trading_capital, FUNDS_DECIMALS) if trading_capital is not None else round(local_current_capital, FUNDS_DECIMALS) recommended_capital = get_recommended_capital(current_capital) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors position_limit_count = count_position_limit_active_monitors(conn) open_guard_enabled = get_trading_day_reset_open_guard_enabled(conn) @@ -6651,7 +6652,7 @@ def api_settings_open_guard(): now = app_now() conn = get_db() trading_day = get_trading_day(now) - from strategy_trade_labels import count_position_limit_active_monitors + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors position_limit_count = count_position_limit_active_monitors(conn) guard_on = get_trading_day_reset_open_guard_enabled(conn) @@ -6954,7 +6955,7 @@ def api_price_snapshot(): pass conn.close() - from hub_position_metrics import build_position_marks_list + from lib.hub.hub_position_metrics import build_position_marks_list position_marks = build_position_marks_list( all_swap_positions, @@ -7108,7 +7109,7 @@ def api_order_kline(): "volume": float(bar[5]), }) - from focus_chart_lib import ( + from lib.instance.focus_chart_lib import ( build_order_kline_order_payload, load_swap_positions_for_order_kline, metrics_for_order_item, @@ -7136,7 +7137,7 @@ def api_order_kline(): ex_metrics=ex_metrics, ) - from focus_chart_lib import kline_api_price_fields + from lib.instance.focus_chart_lib import kline_api_price_fields price_fields = kline_api_price_fields( exchange, @@ -7252,7 +7253,7 @@ def api_key_kline(): "lower_pct": lower_pct, } - from focus_chart_lib import enrich_key_kline_response + from lib.instance.focus_chart_lib import enrich_key_kline_response price_display, key_info = enrich_key_kline_response( symbol=symbol, @@ -7261,7 +7262,7 @@ def api_key_kline(): format_price_fn=format_price_for_symbol, ) - from focus_chart_lib import kline_api_price_fields + from lib.instance.focus_chart_lib import kline_api_price_fields price_fields = kline_api_price_fields( exchange, @@ -8241,7 +8242,7 @@ def del_order(id): opened_at=opened_at, closed_at=closed_at, ) - from account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close on_user_initiated_close( conn, @@ -8255,7 +8256,7 @@ def del_order(id): try: _rcfg = app.extensions.get("strategy_roll_cfg") if isinstance(_rcfg, dict): - from strategy_register import roll_sync_after_external_close + from lib.strategy.strategy_register import roll_sync_after_external_close roll_sync_after_external_close(_rcfg, conn, row["symbol"], row["direction"]) except Exception: @@ -8314,7 +8315,7 @@ def del_order(id): opened_at=opened_at, closed_at=closed_at, ) - from account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_INSTANCE, insert_trade_record_id, on_user_initiated_close on_user_initiated_close( conn, @@ -8328,7 +8329,7 @@ def del_order(id): try: _rcfg = app.extensions.get("strategy_roll_cfg") if isinstance(_rcfg, dict): - from strategy_register import roll_sync_after_external_close + from lib.strategy.strategy_register import roll_sync_after_external_close roll_sync_after_external_close(_rcfg, conn, row["symbol"], row["direction"]) except Exception: @@ -8489,7 +8490,7 @@ def add_journal(): d.get("post_breakeven_stare"), d.get("new_trade_while_occupied"), d.get("note"), image_filename ) ) - from account_risk_lib import on_journal_saved + from lib.trade.account_risk_lib import on_journal_saved on_journal_saved( conn, @@ -8577,7 +8578,7 @@ def api_reviews(): return jsonify([row_to_dict(r) for r in rows]) -_REPO_STATIC_DIR = os.path.join(os.path.dirname(BASE_DIR), "static") +_REPO_STATIC_DIR = common_static_dir(os.path.dirname(BASE_DIR)) _AI_REVIEW_RENDER_JS = os.path.join(_REPO_STATIC_DIR, "ai_review_render.js") _FORM_SUBMIT_GUARD_JS = os.path.join(_REPO_STATIC_DIR, "form_submit_guard.js") _MANUAL_ORDER_RR_PREVIEW_JS = os.path.join(_REPO_STATIC_DIR, "manual_order_rr_preview.js") @@ -8802,7 +8803,7 @@ def api_trade_record_review_update(): tuple(base_params + [rec_id]), ) if reviewed_result == "手动平仓" and reviewed_miss_reason: - from account_risk_lib import apply_manual_close_journal_cooloff + from lib.trade.account_risk_lib import apply_manual_close_journal_cooloff apply_manual_close_journal_cooloff( conn, @@ -8952,7 +8953,7 @@ def _hub_account_bundle(): def _hub_fetch_market(base=""): - from hub_market_info_lib import fetch_usdt_swap_market_info + from lib.hub.hub_market_info_lib import fetch_usdt_swap_market_info return fetch_usdt_swap_market_info( base_or_symbol=base, @@ -8965,7 +8966,7 @@ def _hub_fetch_market(base=""): def _hub_fetch_ohlcv(symbol, timeframe, since_ms=None, limit=500): - from hub_ohlcv_lib import fetch_ohlcv_for_hub + from lib.hub.hub_ohlcv_lib import fetch_ohlcv_for_hub return fetch_ohlcv_for_hub( symbol=symbol, @@ -8981,7 +8982,7 @@ def _hub_fetch_ohlcv(symbol, timeframe, since_ms=None, limit=500): def _hub_fetch_volume_rank(top_n=20): - from hub_volume_rank_lib import fetch_usdt_swap_volume_rank + from lib.hub.hub_volume_rank_lib import fetch_usdt_swap_volume_rank return fetch_usdt_swap_volume_rank( exchange=exchange, @@ -8998,7 +8999,7 @@ try: _repo_root = Path(__file__).resolve().parent.parent if str(_repo_root) not in sys.path: sys.path.insert(0, str(_repo_root)) - from hub_bridge import install_on_app + from lib.hub.hub_bridge import install_on_app install_on_app( app, @@ -9045,8 +9046,8 @@ def strategy_roll_page(): normalize_exchange_symbol = normalize_okx_symbol ensure_exchange_live_ready = ensure_okx_live_ready -from strategy_register import install_strategy_trading -from strategy_trend_register import install_strategy_trend +from lib.strategy.strategy_register import install_strategy_trading +from lib.strategy.strategy_trend_register import install_strategy_trend install_strategy_trading(app, _REPO_ROOT, app_module=sys.modules[__name__]) install_strategy_trend(app, _REPO_ROOT, app_module=sys.modules[__name__]) diff --git a/docs/lib-structure.md b/docs/lib-structure.md new file mode 100644 index 0000000..cc7ed08 --- /dev/null +++ b/docs/lib-structure.md @@ -0,0 +1,147 @@ +# lib/ 共用模块结构 + +四所实例与中控共用的 Python 库、模板与静态资源统一放在仓库根目录的 **`lib/`** 下。部署单元(`crypto_monitor_*`、`manual_trading_hub`)仍保持独立目录与 PM2 配置不变。 + +**重构前快照 Git 标签**:`pre-lib-modularization`(可用 `git checkout pre-lib-modularization` 查看旧布局)。 + +--- + +## 顶层目录 + +``` +crypto_monitor/ +├── crypto_monitor_binance/ # 四所:各自 app + .env + PM2 +├── crypto_monitor_gate/ +├── crypto_monitor_gate_bot/ +├── crypto_monitor_okx/ +├── manual_trading_hub/ # 中控 + 子代理 agent +│ +├── lib/ # 共用模块(本说明) +│ ├── strategy/ +│ ├── key_monitor/ +│ ├── trade/ +│ ├── hub/ +│ ├── ai/ +│ ├── instance/ +│ ├── exchange/ +│ ├── common/ +│ └── paths.py +│ +├── brand/ # 各所共用图标 +├── docs/ +├── deploy/ +├── scripts/ +├── tests/ +├── requirements.txt +└── README.md +``` + +--- + +## lib/ 子包说明 + +| 子包 | 职责 | 主要模块 | +|------|------|----------| +| **`lib/strategy/`** | 策略交易(顺势加仓、趋势回调、快照与记录) | `strategy_register.py`、`strategy_trend_register.py`、`strategy_db.py`、`strategy_roll_*`、`strategy_trend_*` | +| **`lib/strategy/templates/`** | 策略页 Jinja 模板(原 `strategy_templates/`) | `strategy_trading_page.html`、`strategy_roll_panel.html` 等 | +| **`lib/key_monitor/`** | 关键位监控、斐波、假突破、止盈止损方案 | `key_monitor_lib.py`、`fib_key_monitor_lib.py`、`key_sl_tp_lib.py` 等 | +| **`lib/trade/`** | 下单监控展示、计仓、账户风控、手动 SL/TP | `order_monitor_display_lib.py`、`position_sizing_lib.py`、`account_risk_lib.py` 等 | +| **`lib/hub/`** | 中控 API、K 线、归档、计仓器、SSO/Bridge | `hub_bridge.py`、`hub_kline_store.py`、`hub_trades_lib.py` 等 | +| **`lib/ai/`** | AI 复盘与文本生成 | `ai_client.py`、`ai_review_lib.py` | +| **`lib/instance/`** | 中控 iframe 嵌入、导航、复盘图表 | `instance_embed_lib.py`、`focus_chart_lib.py`、`journal_chart_lib.py` | +| **`lib/instance/templates/`** | 嵌入页片段(原 `embed_templates/`) | `embed_page_fragment.html` | +| **`lib/exchange/`** | 特定交易所工具 | `gate_transfer_lib.py`、`okx_orders_lib.py` 等 | +| **`lib/common/`** | 跨功能小工具 | `form_submit_lib.py`、`wechat_notify_lib.py` 等 | +| **`lib/common/static/`** | 四所与中控共用的 JS/CSS(原根目录 `static/`) | `instance_theme.js`、`strategy_roll.js` 等 | + +> **说明**:`hub_*` 命名表示「中控侧能力或行情聚合」,但部分模块(如 `hub_volume_rank_lib`、`hub_market_info_lib`)四所 `app.py` 也会调用,并非中控独占。 + +--- + +## 路径辅助函数 + +`lib/paths.py` 集中维护资源目录,避免硬编码: + +```python +from lib.paths import strategy_templates_dir, embed_templates_dir, common_static_dir + +strategy_templates_dir() # .../lib/strategy/templates +embed_templates_dir() # .../lib/instance/templates +common_static_dir() # .../lib/common/static +``` + +可选传入 `repo_root`(字符串或 `Path`),默认使用 `lib/` 的上级目录即仓库根。 + +--- + +## Python 导入约定 + +各部署目录在启动时将 **仓库根** 加入 `sys.path`(与重构前相同): + +```python +_REPO_ROOT = os.path.dirname(BASE_DIR) # 或 Path(__file__).resolve().parent.parent +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) +``` + +之后使用 **`lib.<子包>.<模块>`** 形式导入,例如: + +```python +from lib.strategy.strategy_db import init_strategy_tables +from lib.key_monitor.key_monitor_lib import check_key_monitors +from lib.hub.hub_bridge import install_on_app +from lib.ai.ai_client import ai_review +``` + +策略注册仍在各所 `app.py` 末尾: + +```python +from lib.strategy.strategy_register import install_strategy_trading +from lib.strategy.strategy_trend_register import install_strategy_trend + +install_strategy_trading(app, _REPO_ROOT, app_module=sys.modules[__name__]) +install_strategy_trend(app, _REPO_ROOT, app_module=sys.modules[__name__]) +``` + +--- + +## 静态资源与 URL + +- 四所页面仍通过 **`/static/...`** 访问共用脚本;`hub_bridge.install_instance_theme_static` 从 `lib/common/static/` 提供部分根级静态路由。 +- 各所目录下 **`static/`**(图标、上传图片等)仍为实例私有,未迁入 `lib/`。 +- 中控 `manual_trading_hub/hub.py` 通过 `_REPO_ROOT / "lib" / "common" / "static"` 挂载与四所共用的 badge、复盘 JS 等。 + +--- + +## 测试 + +在仓库根执行(需将根目录置于 Python 路径,或从根目录运行): + +```bash +cd /opt/crypto_monitor +python -m unittest discover -s tests -p "test_*.py" +``` + +测试文件内统一 `from lib.<子包>.<模块> import ...`。使用 `@patch` 时目标写完整模块路径,例如 `lib.hub.hub_calculator_lib._resolve_market`。 + +--- + +## 迁移脚本 + +一次性迁移由 `scripts/migrate_to_lib.py` 完成(移动文件 + 批量改写 import)。**不要在已迁移后的仓库上重复执行**。 + +--- + +## 后续可选整理 + +- 四所 `app.py` 体量接近,可逐步抽取公共 `exchange_app` 基座(改动面大,单独规划)。 +- `manual_trading_hub/okx_orders_lib.py` 为 agent 本地副本,可与 `lib/exchange/okx_orders_lib.py` 合并去重。 +- 可引入 `pyproject.toml` + `pip install -e .`,替代 `sys.path.insert`(长期维护更规范)。 + +--- + +## 相关文档 + +- [README.md](../README.md) — 总览与部署 +- [策略交易说明.md](../策略交易说明.md) +- [manual_trading_hub/使用说明.md](../manual_trading_hub/使用说明.md) diff --git a/lib/__init__.py b/lib/__init__.py new file mode 100644 index 0000000..54e157b --- /dev/null +++ b/lib/__init__.py @@ -0,0 +1 @@ +"""crypto_monitor shared libraries.""" diff --git a/lib/ai/__init__.py b/lib/ai/__init__.py new file mode 100644 index 0000000..ab164b5 --- /dev/null +++ b/lib/ai/__init__.py @@ -0,0 +1 @@ +"""Shared library package.""" diff --git a/ai_client.py b/lib/ai/ai_client.py similarity index 100% rename from ai_client.py rename to lib/ai/ai_client.py diff --git a/ai_review_lib.py b/lib/ai/ai_review_lib.py similarity index 96% rename from ai_review_lib.py rename to lib/ai/ai_review_lib.py index 0a8293e..21b51e8 100644 --- a/ai_review_lib.py +++ b/lib/ai/ai_review_lib.py @@ -1,180 +1,180 @@ -"""AI 日复盘 / 周复盘:附图收集与 journal 文本格式化(四所共用)。""" -from __future__ import annotations - -import os -import uuid -from typing import Any, Callable, List, Mapping, Optional, Sequence - -from journal_chart_lib import ( - JOURNAL_CHART_ANCHOR_CLOSE, - JOURNAL_CHART_DEFAULT_LIMIT, - JOURNAL_CHART_DEFAULT_TF1, - JOURNAL_CHART_DEFAULT_TF2, - normalize_chart_timeframe, -) - - -def _journal_nz(v: Any, default: str = "无") -> str: - if v is None: - return default - s = str(v).strip() - return s if s else default - - -def _row_get(row: Any, key: str, default: Any = None) -> Any: - """兼容 dict 与 sqlite3.Row(Row 无 .get 方法)。""" - if row is None: - return default - getter = getattr(row, "get", None) - if callable(getter): - return getter(key, default) - try: - keys = row.keys() if hasattr(row, "keys") else () - if key in keys: - return row[key] - except Exception: - pass - try: - return row[key] - except (KeyError, TypeError, IndexError): - return default - - -def journal_row_lines_for_ai( - idx: int, - row: Any, - *, - include_hold_duration: bool = True, -) -> str: - """把 journal 字段拼成给 AI 的文本;四所日复盘/周复盘共用。""" - lines = [ - ( - f"{idx}. {_journal_nz(_row_get(row, 'coin'))} {_journal_nz(_row_get(row, 'tf'))} " - f"| 盈亏:{_journal_nz(_row_get(row, 'pnl'))}U " - f"| 实际RR:{_journal_nz(_row_get(row, 'real_rr'))} " - f"| 预期RR:{_journal_nz(_row_get(row, 'expect_rr'))}" - ), - f" 开仓逻辑:{_journal_nz(_row_get(row, 'entry_reason'))}", - f" 平仓/离场(交易员自述):{_journal_nz(_row_get(row, 'exit_reason'))}", - ] - if include_hold_duration: - lines.append(f" 持仓时长:{_journal_nz(_row_get(row, 'hold_duration'))}") - ee_bits = [ - _journal_nz(_row_get(row, "early_exit")), - _journal_nz(_row_get(row, "early_exit_reason")), - _journal_nz(_row_get(row, "early_exit_trigger")), - _journal_nz(_row_get(row, "early_exit_note")), - ] - if any(x != "无" for x in ee_bits): - lines.append( - " 提前离场记录:" - f"{ee_bits[0]} | 原因:{ee_bits[1]} | 触发:{ee_bits[2]} | 备注:{ee_bits[3]}" - ) - mood_bits = f"心态标签:{_journal_nz(_row_get(row, 'mood_issues'))}" - mood_score = _row_get(row, "mood_score") - if mood_score is not None: - mood_bits += f" | 自评心态分:{mood_score}" - lines.append(f" {mood_bits}") - if _journal_nz(_row_get(row, "post_breakeven_stare")) != "无": - lines.append(f" 保本后盯盘:{_journal_nz(_row_get(row, 'post_breakeven_stare'))}") - if _journal_nz(_row_get(row, "new_trade_while_occupied")) != "无": - lines.append(f" 占用时新开仓:{_journal_nz(_row_get(row, 'new_trade_while_occupied'))}") - if _journal_nz(_row_get(row, "note")) != "无": - lines.append(f" 备注:{_journal_nz(_row_get(row, 'note'))}") - return "\n".join(lines) + "\n" - - -def collect_images_for_ai_review( - rows: Sequence, - upload_folder: str, - *, - build_chart_if_missing: Optional[Callable] = None, -) -> List[str]: - """ - 收集传给视觉模型的本地图片路径。 - - 优先 journal_entries.image 已存附图; - - 若无附图且提供 build_chart_if_missing,则临时生成 K 线图。 - """ - paths: List[str] = [] - seen = set() - upload_folder = os.path.abspath(upload_folder or "") - for row in rows or []: - candidate = None - try: - keys = row.keys() if hasattr(row, "keys") else [] - except Exception: - keys = [] - img = row["image"] if "image" in keys else None - if img: - candidate = os.path.join(upload_folder, str(img).strip()) - elif build_chart_if_missing: - try: - candidate = build_chart_if_missing(row) - except Exception: - candidate = None - if not candidate: - continue - candidate = os.path.abspath(candidate) - if os.path.isfile(candidate) and candidate not in seen: - seen.add(candidate) - paths.append(candidate) - return paths - - -def build_journal_ai_chart_path( - row, - upload_folder: str, - *, - order_chart_enabled: bool, - normalize_exchange_symbol_fn: Callable[[str], str], - generate_chart_fn: Callable, - local_datetime_to_ms_fn: Callable[[str], Optional[int]], - now_ts_ms_fn: Callable[[], int], -) -> Optional[str]: - """无已存附图时,按复盘记录开平仓时间临时生成 K 线图路径。""" - if not order_chart_enabled: - return None - try: - keys = row.keys() if hasattr(row, "keys") else [] - except Exception: - return None - coin = (row["coin"] if "coin" in keys else "") or "" - coin = str(coin).strip() - if not coin: - return None - try: - symbol = normalize_exchange_symbol_fn(coin) - except Exception: - return None - open_dt = row["open_datetime"] if "open_datetime" in keys else "" - close_dt = row["close_datetime"] if "close_datetime" in keys else "" - entry_ms = local_datetime_to_ms_fn(open_dt) - exit_ms = local_datetime_to_ms_fn(close_dt) - if not entry_ms: - return None - row_tf = row["tf"] if "tf" in keys else "" - tf1 = normalize_chart_timeframe(row_tf) or JOURNAL_CHART_DEFAULT_TF1 - tf2 = JOURNAL_CHART_DEFAULT_TF2 if tf1 != JOURNAL_CHART_DEFAULT_TF2 else "1h" - row_id = str(row["id"] if "id" in keys else "")[:8] or uuid.uuid4().hex[:8] - marker = { - "entry_ts_ms": entry_ms, - "exit_ts_ms": exit_ms, - "chart_anchor": JOURNAL_CHART_ANCHOR_CLOSE, - "now_ts_ms": int(now_ts_ms_fn()), - } - fname = f"ai_rev_{row_id}_{uuid.uuid4().hex[:6]}.png" - saved = generate_chart_fn( - symbol, - f"AI复盘 {coin}", - timeframes=[tf1, tf2], - limit=JOURNAL_CHART_DEFAULT_LIMIT, - out_dir=upload_folder, - filename=fname, - marker_payload=marker, - marker_timeframes={tf1, tf2}, - layout="vertical", - ) - if not saved: - return None - path = os.path.join(upload_folder, saved) - return path if os.path.isfile(path) else None +"""AI 日复盘 / 周复盘:附图收集与 journal 文本格式化(四所共用)。""" +from __future__ import annotations + +import os +import uuid +from typing import Any, Callable, List, Mapping, Optional, Sequence + +from lib.instance.journal_chart_lib import ( + JOURNAL_CHART_ANCHOR_CLOSE, + JOURNAL_CHART_DEFAULT_LIMIT, + JOURNAL_CHART_DEFAULT_TF1, + JOURNAL_CHART_DEFAULT_TF2, + normalize_chart_timeframe, +) + + +def _journal_nz(v: Any, default: str = "无") -> str: + if v is None: + return default + s = str(v).strip() + return s if s else default + + +def _row_get(row: Any, key: str, default: Any = None) -> Any: + """兼容 dict 与 sqlite3.Row(Row 无 .get 方法)。""" + if row is None: + return default + getter = getattr(row, "get", None) + if callable(getter): + return getter(key, default) + try: + keys = row.keys() if hasattr(row, "keys") else () + if key in keys: + return row[key] + except Exception: + pass + try: + return row[key] + except (KeyError, TypeError, IndexError): + return default + + +def journal_row_lines_for_ai( + idx: int, + row: Any, + *, + include_hold_duration: bool = True, +) -> str: + """把 journal 字段拼成给 AI 的文本;四所日复盘/周复盘共用。""" + lines = [ + ( + f"{idx}. {_journal_nz(_row_get(row, 'coin'))} {_journal_nz(_row_get(row, 'tf'))} " + f"| 盈亏:{_journal_nz(_row_get(row, 'pnl'))}U " + f"| 实际RR:{_journal_nz(_row_get(row, 'real_rr'))} " + f"| 预期RR:{_journal_nz(_row_get(row, 'expect_rr'))}" + ), + f" 开仓逻辑:{_journal_nz(_row_get(row, 'entry_reason'))}", + f" 平仓/离场(交易员自述):{_journal_nz(_row_get(row, 'exit_reason'))}", + ] + if include_hold_duration: + lines.append(f" 持仓时长:{_journal_nz(_row_get(row, 'hold_duration'))}") + ee_bits = [ + _journal_nz(_row_get(row, "early_exit")), + _journal_nz(_row_get(row, "early_exit_reason")), + _journal_nz(_row_get(row, "early_exit_trigger")), + _journal_nz(_row_get(row, "early_exit_note")), + ] + if any(x != "无" for x in ee_bits): + lines.append( + " 提前离场记录:" + f"{ee_bits[0]} | 原因:{ee_bits[1]} | 触发:{ee_bits[2]} | 备注:{ee_bits[3]}" + ) + mood_bits = f"心态标签:{_journal_nz(_row_get(row, 'mood_issues'))}" + mood_score = _row_get(row, "mood_score") + if mood_score is not None: + mood_bits += f" | 自评心态分:{mood_score}" + lines.append(f" {mood_bits}") + if _journal_nz(_row_get(row, "post_breakeven_stare")) != "无": + lines.append(f" 保本后盯盘:{_journal_nz(_row_get(row, 'post_breakeven_stare'))}") + if _journal_nz(_row_get(row, "new_trade_while_occupied")) != "无": + lines.append(f" 占用时新开仓:{_journal_nz(_row_get(row, 'new_trade_while_occupied'))}") + if _journal_nz(_row_get(row, "note")) != "无": + lines.append(f" 备注:{_journal_nz(_row_get(row, 'note'))}") + return "\n".join(lines) + "\n" + + +def collect_images_for_ai_review( + rows: Sequence, + upload_folder: str, + *, + build_chart_if_missing: Optional[Callable] = None, +) -> List[str]: + """ + 收集传给视觉模型的本地图片路径。 + - 优先 journal_entries.image 已存附图; + - 若无附图且提供 build_chart_if_missing,则临时生成 K 线图。 + """ + paths: List[str] = [] + seen = set() + upload_folder = os.path.abspath(upload_folder or "") + for row in rows or []: + candidate = None + try: + keys = row.keys() if hasattr(row, "keys") else [] + except Exception: + keys = [] + img = row["image"] if "image" in keys else None + if img: + candidate = os.path.join(upload_folder, str(img).strip()) + elif build_chart_if_missing: + try: + candidate = build_chart_if_missing(row) + except Exception: + candidate = None + if not candidate: + continue + candidate = os.path.abspath(candidate) + if os.path.isfile(candidate) and candidate not in seen: + seen.add(candidate) + paths.append(candidate) + return paths + + +def build_journal_ai_chart_path( + row, + upload_folder: str, + *, + order_chart_enabled: bool, + normalize_exchange_symbol_fn: Callable[[str], str], + generate_chart_fn: Callable, + local_datetime_to_ms_fn: Callable[[str], Optional[int]], + now_ts_ms_fn: Callable[[], int], +) -> Optional[str]: + """无已存附图时,按复盘记录开平仓时间临时生成 K 线图路径。""" + if not order_chart_enabled: + return None + try: + keys = row.keys() if hasattr(row, "keys") else [] + except Exception: + return None + coin = (row["coin"] if "coin" in keys else "") or "" + coin = str(coin).strip() + if not coin: + return None + try: + symbol = normalize_exchange_symbol_fn(coin) + except Exception: + return None + open_dt = row["open_datetime"] if "open_datetime" in keys else "" + close_dt = row["close_datetime"] if "close_datetime" in keys else "" + entry_ms = local_datetime_to_ms_fn(open_dt) + exit_ms = local_datetime_to_ms_fn(close_dt) + if not entry_ms: + return None + row_tf = row["tf"] if "tf" in keys else "" + tf1 = normalize_chart_timeframe(row_tf) or JOURNAL_CHART_DEFAULT_TF1 + tf2 = JOURNAL_CHART_DEFAULT_TF2 if tf1 != JOURNAL_CHART_DEFAULT_TF2 else "1h" + row_id = str(row["id"] if "id" in keys else "")[:8] or uuid.uuid4().hex[:8] + marker = { + "entry_ts_ms": entry_ms, + "exit_ts_ms": exit_ms, + "chart_anchor": JOURNAL_CHART_ANCHOR_CLOSE, + "now_ts_ms": int(now_ts_ms_fn()), + } + fname = f"ai_rev_{row_id}_{uuid.uuid4().hex[:6]}.png" + saved = generate_chart_fn( + symbol, + f"AI复盘 {coin}", + timeframes=[tf1, tf2], + limit=JOURNAL_CHART_DEFAULT_LIMIT, + out_dir=upload_folder, + filename=fname, + marker_payload=marker, + marker_timeframes={tf1, tf2}, + layout="vertical", + ) + if not saved: + return None + path = os.path.join(upload_folder, saved) + return path if os.path.isfile(path) else None diff --git a/lib/common/__init__.py b/lib/common/__init__.py new file mode 100644 index 0000000..ab164b5 --- /dev/null +++ b/lib/common/__init__.py @@ -0,0 +1 @@ +"""Shared library package.""" diff --git a/auto_transfer_daily_lib.py b/lib/common/auto_transfer_daily_lib.py similarity index 100% rename from auto_transfer_daily_lib.py rename to lib/common/auto_transfer_daily_lib.py diff --git a/form_submit_lib.py b/lib/common/form_submit_lib.py similarity index 100% rename from form_submit_lib.py rename to lib/common/form_submit_lib.py diff --git a/history_window_lib.py b/lib/common/history_window_lib.py similarity index 100% rename from history_window_lib.py rename to lib/common/history_window_lib.py diff --git a/static/account_risk_badge.css b/lib/common/static/account_risk_badge.css similarity index 100% rename from static/account_risk_badge.css rename to lib/common/static/account_risk_badge.css diff --git a/static/account_risk_badge.js b/lib/common/static/account_risk_badge.js similarity index 100% rename from static/account_risk_badge.js rename to lib/common/static/account_risk_badge.js diff --git a/static/ai_review_render.js b/lib/common/static/ai_review_render.js similarity index 100% rename from static/ai_review_render.js rename to lib/common/static/ai_review_render.js diff --git a/static/focus_chart_page.css b/lib/common/static/focus_chart_page.css similarity index 100% rename from static/focus_chart_page.css rename to lib/common/static/focus_chart_page.css diff --git a/static/focus_chart_page.js b/lib/common/static/focus_chart_page.js similarity index 100% rename from static/focus_chart_page.js rename to lib/common/static/focus_chart_page.js diff --git a/static/form_submit_guard.js b/lib/common/static/form_submit_guard.js similarity index 100% rename from static/form_submit_guard.js rename to lib/common/static/form_submit_guard.js diff --git a/static/instance_embed.js b/lib/common/static/instance_embed.js similarity index 100% rename from static/instance_embed.js rename to lib/common/static/instance_embed.js diff --git a/static/instance_page.css b/lib/common/static/instance_page.css similarity index 100% rename from static/instance_page.css rename to lib/common/static/instance_page.css diff --git a/static/instance_records_mobile.js b/lib/common/static/instance_records_mobile.js similarity index 100% rename from static/instance_records_mobile.js rename to lib/common/static/instance_records_mobile.js diff --git a/static/instance_theme.css b/lib/common/static/instance_theme.css similarity index 100% rename from static/instance_theme.css rename to lib/common/static/instance_theme.css diff --git a/static/instance_theme.js b/lib/common/static/instance_theme.js similarity index 100% rename from static/instance_theme.js rename to lib/common/static/instance_theme.js diff --git a/static/instance_theme_early.css b/lib/common/static/instance_theme_early.css similarity index 100% rename from static/instance_theme_early.css rename to lib/common/static/instance_theme_early.css diff --git a/static/instance_ui.js b/lib/common/static/instance_ui.js similarity index 100% rename from static/instance_ui.js rename to lib/common/static/instance_ui.js diff --git a/static/key_monitor_form.js b/lib/common/static/key_monitor_form.js similarity index 100% rename from static/key_monitor_form.js rename to lib/common/static/key_monitor_form.js diff --git a/static/manual_order_rr_preview.js b/lib/common/static/manual_order_rr_preview.js similarity index 100% rename from static/manual_order_rr_preview.js rename to lib/common/static/manual_order_rr_preview.js diff --git a/static/strategy_roll.js b/lib/common/static/strategy_roll.js similarity index 100% rename from static/strategy_roll.js rename to lib/common/static/strategy_roll.js diff --git a/static/time_close_ui.js b/lib/common/static/time_close_ui.js similarity index 100% rename from static/time_close_ui.js rename to lib/common/static/time_close_ui.js diff --git a/static/trade_stats_calendar.css b/lib/common/static/trade_stats_calendar.css similarity index 100% rename from static/trade_stats_calendar.css rename to lib/common/static/trade_stats_calendar.css diff --git a/static/trade_stats_calendar.js b/lib/common/static/trade_stats_calendar.js similarity index 100% rename from static/trade_stats_calendar.js rename to lib/common/static/trade_stats_calendar.js diff --git a/wechat_notify_lib.py b/lib/common/wechat_notify_lib.py similarity index 100% rename from wechat_notify_lib.py rename to lib/common/wechat_notify_lib.py diff --git a/lib/exchange/__init__.py b/lib/exchange/__init__.py new file mode 100644 index 0000000..ab164b5 --- /dev/null +++ b/lib/exchange/__init__.py @@ -0,0 +1 @@ +"""Shared library package.""" diff --git a/gate_position_history_lib.py b/lib/exchange/gate_position_history_lib.py similarity index 100% rename from gate_position_history_lib.py rename to lib/exchange/gate_position_history_lib.py diff --git a/gate_transfer_lib.py b/lib/exchange/gate_transfer_lib.py similarity index 100% rename from gate_transfer_lib.py rename to lib/exchange/gate_transfer_lib.py diff --git a/okx_orders_lib.py b/lib/exchange/okx_orders_lib.py similarity index 100% rename from okx_orders_lib.py rename to lib/exchange/okx_orders_lib.py diff --git a/lib/hub/__init__.py b/lib/hub/__init__.py new file mode 100644 index 0000000..ab164b5 --- /dev/null +++ b/lib/hub/__init__.py @@ -0,0 +1 @@ +"""Shared library package.""" diff --git a/hub_auth.py b/lib/hub/hub_auth.py similarity index 92% rename from hub_auth.py rename to lib/hub/hub_auth.py index cfa867d..db2e586 100644 --- a/hub_auth.py +++ b/lib/hub/hub_auth.py @@ -1,36 +1,36 @@ -"""中控调用实例 API 时的鉴权(Flask request 头 X-Hub-Token)。SSO 见 hub_sso.py。""" -from __future__ import annotations - -import os - -from hub_sso import ( - HUB_SSO_TTL_SEC, - hub_bridge_token, - mint_hub_sso_token, - safe_next_path, - verify_hub_sso_token, -) - -__all__ = [ - "HUB_SSO_TTL_SEC", - "hub_bridge_token", - "mint_hub_sso_token", - "safe_next_path", - "verify_hub_sso_token", - "request_allowed", -] - - -def request_allowed(session_logged_in: bool, auth_disabled: bool) -> bool: - if auth_disabled or session_logged_in: - return True - tok = hub_bridge_token() - if not tok: - return False - try: - from flask import request - except ImportError: - return False - if request.headers.get("X-Hub-Token") == tok: - return True - return False +"""中控调用实例 API 时的鉴权(Flask request 头 X-Hub-Token)。SSO 见 hub_sso.py。""" +from __future__ import annotations + +import os + +from lib.hub.hub_sso import ( + HUB_SSO_TTL_SEC, + hub_bridge_token, + mint_hub_sso_token, + safe_next_path, + verify_hub_sso_token, +) + +__all__ = [ + "HUB_SSO_TTL_SEC", + "hub_bridge_token", + "mint_hub_sso_token", + "safe_next_path", + "verify_hub_sso_token", + "request_allowed", +] + + +def request_allowed(session_logged_in: bool, auth_disabled: bool) -> bool: + if auth_disabled or session_logged_in: + return True + tok = hub_bridge_token() + if not tok: + return False + try: + from flask import request + except ImportError: + return False + if request.headers.get("X-Hub-Token") == tok: + return True + return False diff --git a/hub_bridge.py b/lib/hub/hub_bridge.py similarity index 94% rename from hub_bridge.py rename to lib/hub/hub_bridge.py index 49f1e57..2154655 100644 --- a/hub_bridge.py +++ b/lib/hub/hub_bridge.py @@ -1,1025 +1,1027 @@ -""" -各 crypto_monitor_* 注册 /api/hub/* JSON 接口,供 manual_trading_hub 调用。 -实例末尾:app.config["HUB_CTX"] = {...}; register_hub_routes(app) -""" - -from __future__ import annotations - -import json -import time -from functools import wraps - -from flask import ( - current_app, - flash, - get_flashed_messages, - jsonify, - redirect, - request, - session, -) - -from hub_auth import request_allowed -from hub_sso import ( - mint_hub_embed_bootstrap, - safe_next_path, - verify_hub_embed_bootstrap, - verify_hub_sso_token, -) - - -def _merge_query_into_path(path: str, **params: str) -> str: - from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit - - split = urlsplit(path or "/") - q = list(parse_qsl(split.query, keep_blank_values=True)) - keys = {k for k, _ in q} - for k, v in params.items(): - if not v or k in keys: - continue - q.append((k, str(v))) - return urlunsplit((split.scheme, split.netloc, split.path, urlencode(q), split.fragment)) - - -def install_instance_theme_static(app) -> None: - """仓库根 static/instance_theme.* 供四所页面共用。""" - import os - - from flask import Response, send_file - - repo_static = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static") - assets = { - "instance_theme.js": "application/javascript; charset=utf-8", - "instance_theme_early.css": "text/css; charset=utf-8", - "instance_theme.css": "text/css; charset=utf-8", - "account_risk_badge.css": "text/css; charset=utf-8", - "account_risk_badge.js": "application/javascript; charset=utf-8", - "instance_ui.js": "application/javascript; charset=utf-8", - "instance_records_mobile.js": "application/javascript; charset=utf-8", - "ai_review_render.js": "application/javascript; charset=utf-8", - "form_submit_guard.js": "application/javascript; charset=utf-8", - "key_monitor_form.js": "application/javascript; charset=utf-8", - "time_close_ui.js": "application/javascript; charset=utf-8", - "manual_order_rr_preview.js": "application/javascript; charset=utf-8", - "strategy_roll.js": "application/javascript; charset=utf-8", - "instance_page.css": "text/css; charset=utf-8", - "instance_embed.js": "application/javascript; charset=utf-8", - "focus_chart_page.js": "application/javascript; charset=utf-8", - "focus_chart_page.css": "text/css; charset=utf-8", - "trade_stats_calendar.js": "application/javascript; charset=utf-8", - "trade_stats_calendar.css": "text/css; charset=utf-8", - } - - for name, mime in assets.items(): - path = os.path.join(repo_static, name) - - def _view(p=path, m=mime): - if not os.path.isfile(p): - return Response("not found", status=404, mimetype="text/plain; charset=utf-8") - return send_file(p, mimetype=m) - - app.add_url_rule( - f"/static/{name}", - endpoint=f"repo_static_{name.replace('.', '_')}", - view_func=_view, - ) - - -def register_trade_stats_calendar_route( - app, - *, - login_required_fn, - load_pnls_fn, - row_matches_segment_fn, - reset_hour: int, - get_db_fn=None, -): - """四所统计分析页:按月返回各交易日盈亏/笔数。""" - from flask import jsonify, request - - from trade_stats_calendar_lib import build_trade_stats_calendar - - @app.route("/api/stats/calendar") - @login_required_fn - def api_stats_calendar(): - year = request.args.get("year", type=int) - month = request.args.get("month", type=int) - segment = (request.args.get("segment") or "all").strip() or "all" - if not year or not month: - from datetime import datetime - - now = datetime.now() - year = year or now.year - month = month or now.month - get_db = get_db_fn or (app.config.get("HUB_CTX") or {}).get("get_db") - if not get_db: - return jsonify({"ok": False, "msg": "未配置数据库"}), 500 - conn = get_db() - try: - pnls = load_pnls_fn(conn) - finally: - conn.close() - try: - payload = build_trade_stats_calendar( - pnls, - year, - month, - segment, - row_matches_segment_fn, - reset_hour=int(reset_hour), - ) - except ValueError as exc: - return jsonify({"ok": False, "msg": str(exc)}), 400 - return jsonify({"ok": True, **payload}) - - -def _hub_auth_required(f): - @wraps(f) - def wrapped(*args, **kwargs): - from flask import current_app as cap - - auth_disabled = bool(cap.config.get("HUB_AUTH_DISABLED")) - if not request_allowed(bool(session.get("logged_in")), auth_disabled): - return jsonify({"ok": False, "msg": "未授权(登录或 HUB_BRIDGE_TOKEN)"}), 401 - return f(*args, **kwargs) - - return wrapped - - -def _ctx(): - return current_app.config.get("HUB_CTX") or {} - - -def _row_to_dict(row): - fn = _ctx().get("row_to_dict") - if fn and row is not None: - return fn(row) - return dict(row) if row is not None else {} - - -def build_hub_monitor_payload( - *, - keys, - orders, - trends, - rolls, - enrich=None, - risk_status=None, -) -> dict: - """合并 enrich 增量字段;enrich 只返回 trends 等局部时不得丢掉 keys/orders。""" - payload = { - "ok": True, - "keys": keys, - "orders": orders, - "trends": trends, - "rolls": rolls, - "key_prices": [], - } - if isinstance(risk_status, dict): - payload["risk_status"] = risk_status - if callable(enrich): - extra = enrich(keys=keys, orders=orders, trends=trends, rolls=rolls) - if isinstance(extra, dict): - payload.update(extra) - return payload - - -_FAIL_HINTS = ( - "失败", - "错误", - "拒绝", - "无效", - "缺少", - "无法", - "过期", - "未达", - "不能为空", - "已有", - "不允许", - "异常", -) - - -def _invoke_view(view_name: str, path: str, form=None) -> dict: - views = _ctx().get("views") or {} - view = views.get(view_name) - if not view: - return {"ok": False, "messages": [f"未配置视图 {view_name}"]} - data = form if form is not None else request.form - if hasattr(data, "items") and not isinstance(data, dict): - data = {k: v for k, v in data.items()} - with current_app.test_request_context(path, method="POST", data=data): - session["logged_in"] = True - try: - view() - except Exception as e: - return {"ok": False, "messages": [str(e)]} - try: - msgs = [str(x) for x in get_flashed_messages()] - except Exception as e: - return {"ok": False, "messages": [f"读取提示信息失败: {e}"]} - ok = True - for m in msgs: - if any(k in m for k in _FAIL_HINTS): - ok = False - break - return {"ok": ok, "messages": msgs} - - -def _invoke_view_get(view_name: str, path: str) -> dict: - views = _ctx().get("views") or {} - view = views.get(view_name) - if not view: - return {"ok": False, "messages": [f"未配置视图 {view_name}"]} - with current_app.test_request_context(path, method="GET"): - session["logged_in"] = True - try: - view() - except Exception as e: - return {"ok": False, "messages": [str(e)]} - try: - msgs = [str(x) for x in get_flashed_messages()] - except Exception as e: - return {"ok": False, "messages": [f"读取提示信息失败: {e}"]} - ok = True - for m in msgs: - if any(k in m for k in _FAIL_HINTS): - ok = False - break - return {"ok": ok, "messages": msgs} - - -def _hub_json(view_name: str, path: str, form=None): - try: - return jsonify(_invoke_view(view_name, path, form=form)) - except Exception as e: - return jsonify({"ok": False, "messages": [str(e)]}) - - -def _embed_login_dest(next_path: str) -> str: - """embed=1 时把 /trade 等映射到 /embed?tab=…""" - ht = (request.args.get("hub_theme") or "").strip().lower() - hub_theme = ht if ht in ("light", "dark") else None - if request.args.get("embed", "").strip().lower() in ("1", "true", "yes", "on"): - from instance_embed_lib import rewrite_embed_dest - - return rewrite_embed_dest(next_path, hub_theme=hub_theme) - if hub_theme: - return _merge_query_into_path(next_path, hub_theme=hub_theme) - return next_path - - -def install_on_app( - app, - *, - exchange: str, - capabilities: list, - has_trend: bool, - get_db, - row_to_dict, - meta_fn, - views: dict, - ohlcv_fn=None, - account_fn=None, - volume_rank_fn=None, - market_fn=None, - reconcile_hub_flat_fn=None, - risk_status_fn=None, - user_close_fn=None, - render_main_page_fn=None, - login_required_fn=None, -): - app.config["HUB_CTX"] = { - "exchange": exchange, - "capabilities": list(capabilities), - "has_trend": bool(has_trend), - "get_db": get_db, - "row_to_dict": row_to_dict, - "meta_fn": meta_fn, - "account_fn": account_fn, - "views": views, - "ohlcv_fn": ohlcv_fn, - "volume_rank_fn": volume_rank_fn, - "market_fn": market_fn, - "reconcile_hub_flat_fn": reconcile_hub_flat_fn, - "risk_status_fn": risk_status_fn, - "user_close_fn": user_close_fn, - } - install_hub_embed_headers(app) - configure_hub_embed_session(app) - install_instance_theme_static(app) - register_hub_routes(app) - if render_main_page_fn and login_required_fn: - import os - - from instance_embed_lib import attach_embed_templates, register_embed_routes - - repo_root = os.path.dirname(os.path.abspath(__file__)) - attach_embed_templates(app, repo_root) - register_embed_routes(app, login_required_fn, render_main_page_fn) - - -def configure_hub_embed_session(app): - """HTTPS iframe 内嵌须 SameSite=None + Secure;hub-sso / hub-embed-auth 自动启用。""" - import os - - allowed = (os.getenv("APP_ALLOW_HUB_EMBED") or "true").strip().lower() in ( - "1", - "true", - "yes", - "on", - ) - if not allowed: - return - - secure_env = (os.getenv("APP_COOKIE_SECURE") or "auto").strip().lower() - if secure_env in ("1", "true", "yes", "on"): - app.config.update( - SESSION_COOKIE_SECURE=True, - SESSION_COOKIE_SAMESITE="None", - SESSION_COOKIE_HTTPONLY=True, - ) - return - - @app.before_request - def _hub_embed_session_cookie(): - if request.path not in ("/hub-sso", "/hub-embed-auth"): - return - embed = (request.args.get("embed") or "").strip().lower() in ( - "1", - "true", - "yes", - "on", - ) - in_iframe = (request.headers.get("Sec-Fetch-Dest") or "").lower() == "iframe" - if not embed and not in_iframe: - return - if not request.is_secure: - return - app.config["SESSION_COOKIE_SECURE"] = True - app.config["SESSION_COOKIE_SAMESITE"] = "None" - app.config["SESSION_COOKIE_HTTPONLY"] = True - - -def _sso_wants_embed_auth() -> bool: - embed = (request.args.get("embed") or "").strip().lower() in ( - "1", - "true", - "yes", - "on", - ) - in_iframe = (request.headers.get("Sec-Fetch-Dest") or "").lower() == "iframe" - return bool(embed or in_iframe) - - -def install_hub_embed_headers(app): - """允许复盘中控 iframe 内嵌打开本实例(须与 hub 的 HUB_EMBED_ORIGINS 或域名一致)。""" - import os - - allowed = (os.getenv("APP_ALLOW_HUB_EMBED") or "true").strip().lower() in ( - "1", - "true", - "yes", - "on", - ) - if not allowed: - return - origins = ( - (os.getenv("HUB_EMBED_PARENT_ORIGINS") or os.getenv("HUB_EMBED_ORIGINS") or "*") - .strip() - ) - - @app.after_request - def _hub_embed_frame_headers(response): - if origins == "*": - response.headers["Content-Security-Policy"] = "frame-ancestors *" - else: - response.headers["Content-Security-Policy"] = ( - f"frame-ancestors 'self' {origins}" - ) - return response - - -def register_hub_routes(app): - auth_disabled = False - try: - import os - - auth_disabled = os.getenv("APP_AUTH_DISABLED", "false").lower() in ( - "1", - "true", - "yes", - "on", - ) - except Exception: - pass - app.config.setdefault("HUB_AUTH_DISABLED", auth_disabled) - - @app.route("/api/hub/ping") - @_hub_auth_required - def api_hub_ping(): - c = _ctx() - return jsonify( - { - "ok": True, - "exchange": c.get("exchange"), - "capabilities": c.get("capabilities") or [], - } - ) - - @app.route("/api/hub/meta") - @_hub_auth_required - def api_hub_meta(): - c = _ctx() - meta_fn = c.get("meta_fn") - meta = meta_fn() if callable(meta_fn) else {} - return jsonify({"ok": True, "meta": meta}) - - @app.route("/api/hub/account") - @_hub_auth_required - def api_hub_account(): - """中控 AI:资金账户 / 交易账户余额(无需浏览器登录)。""" - fn = _ctx().get("account_fn") - if not callable(fn): - return jsonify({"ok": False, "msg": "未配置 account_fn"}), 501 - try: - data = fn() - if not isinstance(data, dict): - data = {} - return jsonify({"ok": True, **data}) - except Exception as e: - return jsonify({"ok": False, "msg": str(e)}), 500 - - @app.route("/api/account_risk_status") - @_hub_auth_required - def api_account_risk_status(): - c = _ctx() - get_db = c.get("get_db") - risk_fn = c.get("risk_status_fn") - if not callable(get_db) or not callable(risk_fn): - return jsonify({"ok": False, "msg": "未配置风控"}), 501 - conn = get_db() - try: - payload = risk_fn(conn) - return jsonify({"ok": True, **(payload if isinstance(payload, dict) else {})}) - except Exception as e: - return jsonify({"ok": False, "msg": str(e)}), 500 - finally: - conn.close() - - @app.route("/api/hub/account-risk/user-close", methods=["POST"]) - @_hub_auth_required - def api_hub_account_risk_user_close(): - """中控/实例:登记用户主动平仓(计入冷静期与日冻结)。""" - c = _ctx() - get_db = c.get("get_db") - user_close_fn = c.get("user_close_fn") - if not callable(get_db) or not callable(user_close_fn): - return jsonify({"ok": False, "msg": "未配置 user_close_fn"}), 501 - body = request.get_json(silent=True) or {} - source = (body.get("source") or request.form.get("source") or "").strip() - try: - count = max(0, int(body.get("count") if body.get("count") is not None else 1)) - except (TypeError, ValueError): - count = 1 - trade_record_id = body.get("trade_record_id") - closed_at_ms = body.get("closed_at_ms") - if count <= 0: - return jsonify({"ok": True, "skipped": True, "count": 0}) - conn = get_db() - try: - user_close_fn( - conn, - source=source, - count=count, - trade_record_id=trade_record_id, - closed_at_ms=closed_at_ms, - ) - conn.commit() - return jsonify({"ok": True, "count": count, "source": source}) - except Exception as e: - return jsonify({"ok": False, "msg": str(e)}), 500 - finally: - conn.close() - - @app.route("/api/hub/monitor") - @_hub_auth_required - def api_hub_monitor(): - c = _ctx() - get_db = c.get("get_db") - if not get_db: - return jsonify({"ok": False, "msg": "HUB_CTX 缺少 get_db"}), 500 - conn = get_db() - keys = [] - for row in conn.execute("SELECT * FROM key_monitors ORDER BY id DESC").fetchall(): - keys.append(_row_to_dict(row)) - orders = [] - for row in conn.execute( - "SELECT * FROM order_monitors WHERE status='active' ORDER BY id DESC" - ).fetchall(): - od = _row_to_dict(row) - try: - from strategy_trade_labels import apply_order_monitor_source_labels - - od = apply_order_monitor_source_labels(od) - except Exception: - pass - orders.append(od) - trends = [] - if c.get("has_trend"): - for row in conn.execute( - "SELECT * FROM trend_pullback_plans WHERE status='active' ORDER BY id DESC" - ).fetchall(): - trends.append(_row_to_dict(row)) - rolls = [] - try: - for row in conn.execute( - """SELECT g.* FROM roll_groups g - INNER JOIN order_monitors m ON m.id = g.order_monitor_id AND m.status='active' - WHERE g.status='active' ORDER BY g.id DESC""" - ).fetchall(): - rolls.append(_row_to_dict(row)) - except Exception: - pass - risk_status = None - risk_fn = c.get("risk_status_fn") - if callable(risk_fn): - try: - risk_status = risk_fn(conn) - except Exception: - risk_status = None - conn.close() - enrich = c.get("enrich_monitor") - if callable(enrich): - try: - return jsonify( - build_hub_monitor_payload( - keys=keys, - orders=orders, - trends=trends, - rolls=rolls, - enrich=enrich, - risk_status=risk_status, - ) - ) - except Exception as e: - return jsonify({"ok": False, "msg": str(e)}), 500 - return jsonify( - build_hub_monitor_payload( - keys=keys, - orders=orders, - trends=trends, - rolls=rolls, - risk_status=risk_status, - ) - ) - - @app.route("/api/hub/trades/archive") - @_hub_auth_required - def api_hub_trades_archive(): - """中控币种档案:近 N 天已平仓记录。""" - from hub_trades_lib import fetch_trades_for_archive, summarize_trades - - c = _ctx() - get_db = c.get("get_db") - if not get_db: - return jsonify({"ok": False, "msg": "HUB_CTX 缺少 get_db"}), 500 - try: - days = int(request.args.get("days") or "365") - except ValueError: - days = 365 - try: - limit = int(request.args.get("limit") or "2000") - except ValueError: - limit = 2000 - try: - import os - - reset_hour = int(os.getenv("TRADING_DAY_RESET_HOUR", "8") or "8") - except ValueError: - reset_hour = 8 - conn = get_db() - try: - trades = fetch_trades_for_archive( - conn, - exchange_key=str(c.get("exchange") or ""), - days=days, - row_to_dict_fn=c.get("row_to_dict"), - reset_hour=reset_hour, - limit=limit, - ) - finally: - conn.close() - stats = summarize_trades(trades) - return jsonify( - { - "ok": True, - "days": max(1, min(days, 3650)), - "trading_day_reset_hour": reset_hour, - "trades": trades, - "stats": stats, - } - ) - - @app.route("/api/hub/trades/today") - @_hub_auth_required - def api_hub_trades_today(): - """中控 AI:当日已平仓记录(按实例交易日)。""" - from hub_trades_lib import ( - current_trading_day, - fetch_trades_for_trading_day, - summarize_trades, - ) - - c = _ctx() - get_db = c.get("get_db") - if not get_db: - return jsonify({"ok": False, "msg": "HUB_CTX 缺少 get_db"}), 500 - day_arg = (request.args.get("trading_day") or request.args.get("date") or "").strip()[:10] - try: - import os - - reset_hour = int(os.getenv("TRADING_DAY_RESET_HOUR", "8") or "8") - except ValueError: - reset_hour = 8 - trading_day = day_arg or current_trading_day(reset_hour=reset_hour) - conn = get_db() - try: - trades = fetch_trades_for_trading_day( - conn, - trading_day, - row_to_dict_fn=c.get("row_to_dict"), - reset_hour=reset_hour, - ) - finally: - conn.close() - stats = summarize_trades(trades) - return jsonify( - { - "ok": True, - "trading_day": trading_day, - "trading_day_reset_hour": reset_hour, - "trades": trades, - "stats": stats, - } - ) - - @app.route("/api/hub/volume-rank") - @_hub_auth_required - def api_hub_volume_rank(): - fn = _ctx().get("volume_rank_fn") - if not callable(fn): - return jsonify({"ok": False, "msg": "该实例未配置成交量排名接口"}), 501 - top_raw = (request.args.get("top") or "").strip() - top_n = 20 - if top_raw.isdigit(): - top_n = int(top_raw) - try: - result = fn(top_n=top_n) - if isinstance(result, dict): - return jsonify(result) - return jsonify({"ok": False, "msg": "成交量排名返回格式无效"}), 500 - except Exception as e: - return jsonify({"ok": False, "msg": str(e)}), 500 - - @app.route("/api/hub/market") - @_hub_auth_required - def api_hub_market(): - fn = _ctx().get("market_fn") - if not callable(fn): - return jsonify({"ok": False, "msg": "该实例未配置合约信息接口"}), 501 - base = (request.args.get("base") or request.args.get("symbol") or "").strip() - try: - result = fn(base=base) - if isinstance(result, dict): - return jsonify(result) - return jsonify({"ok": False, "msg": "合约信息返回格式无效"}), 500 - except Exception as e: - return jsonify({"ok": False, "msg": str(e)}), 500 - - @app.route("/api/hub/ohlcv") - @_hub_auth_required - def api_hub_ohlcv(): - fn = _ctx().get("ohlcv_fn") - if not callable(fn): - return jsonify({"ok": False, "msg": "该实例未配置 OHLCV 接口"}), 501 - symbol = (request.args.get("symbol") or "").strip() - timeframe = (request.args.get("timeframe") or "5m").strip() - since_raw = (request.args.get("since_ms") or "").strip() - limit_raw = (request.args.get("limit") or "").strip() - since_ms = None - if since_raw.isdigit(): - since_ms = int(since_raw) - limit = 500 - if limit_raw.isdigit(): - limit = int(limit_raw) - try: - result = fn(symbol=symbol, timeframe=timeframe, since_ms=since_ms, limit=limit) - if isinstance(result, dict): - return jsonify(result) - return jsonify({"ok": False, "msg": "OHLCV 返回格式无效"}), 500 - except Exception as e: - return jsonify({"ok": False, "msg": str(e)}), 500 - - @app.route("/api/hub/add_order", methods=["POST"]) - @_hub_auth_required - def api_hub_add_order(): - return _hub_json("add_order", "/add_order") - - @app.route("/api/hub/add_key", methods=["POST"]) - @_hub_auth_required - def api_hub_add_key(): - return _hub_json("add_key", "/add_key") - - @app.route("/api/hub/trend/preview", methods=["POST"]) - @_hub_auth_required - def api_hub_trend_preview(): - if not _ctx().get("has_trend"): - return jsonify({"ok": False, "msg": "该实例无趋势回调"}), 400 - data = _invoke_view("preview_trend_pullback", "/trade") - pid = _latest_preview_id() - preview = _fetch_preview(pid) if pid else None - return jsonify( - { - "ok": bool(data.get("ok")), - "messages": data.get("messages") or [], - "preview_id": pid, - "preview": preview, - } - ) - - @app.route("/api/hub/trend/execute", methods=["POST"]) - @_hub_auth_required - def api_hub_trend_execute(): - if not _ctx().get("has_trend"): - return jsonify({"ok": False, "msg": "该实例无趋势回调"}), 400 - pid = (request.form.get("preview_id") or "").strip() - if not pid: - body = request.get_json(silent=True) or {} - pid = str(body.get("preview_id") or "").strip() - form = {"preview_id": pid} if pid else {} - return jsonify(_invoke_view("execute_trend_pullback", "/trade", form=form)) - - @app.route("/api/hub/trend/preview/") - @_hub_auth_required - def api_hub_trend_preview_get(pid): - if not _ctx().get("has_trend"): - return jsonify({"ok": False, "msg": "该实例无趋势回调"}), 400 - preview = _fetch_preview(pid) - if not preview: - return jsonify({"ok": False, "msg": "预览不存在或已过期"}), 404 - return jsonify({"ok": True, "preview": preview}) - - @app.route("/api/hub/trend/stop/", methods=["POST"]) - @_hub_auth_required - def api_hub_trend_stop(pid): - if not _ctx().get("has_trend"): - return jsonify({"ok": False, "msg": "该实例无趋势回调"}), 400 - return jsonify(_invoke_view_get("stop_trend_pullback", f"/stop_trend_pullback/{pid}")) - - @app.route("/api/hub/order/sync-flat", methods=["POST"]) - @_hub_auth_required - def api_hub_order_sync_flat(): - """中控市价全平后:同步 order_monitors 并读 Gate 平仓历史写交易记录。""" - fn = _ctx().get("reconcile_hub_flat_fn") - if not callable(fn): - return jsonify({"ok": False, "msg": "该实例未配置 order sync-flat"}), 400 - body = request.get_json(silent=True) or {} - symbol = (body.get("symbol") or request.form.get("symbol") or "").strip() - side = ( - body.get("side") - or body.get("direction") - or request.form.get("side") - or "" - ).strip().lower() - if not symbol: - return jsonify({"ok": False, "msg": "symbol 不能为空"}), 400 - if side not in ("long", "short"): - return jsonify({"ok": False, "msg": "side 须为 long 或 short"}), 400 - get_db = _ctx().get("get_db") - if not callable(get_db): - return jsonify({"ok": False, "msg": "HUB_CTX 缺少 get_db"}), 500 - conn = get_db() - try: - out = fn(conn, symbol, side) - if not isinstance(out, dict): - out = {"ok": True, "synced": int(out or 0)} - conn.commit() - return jsonify(out) - except Exception as e: - return jsonify({"ok": False, "msg": str(e)}), 500 - finally: - conn.close() - - @app.route("/api/hub/trend/sync-flat", methods=["POST"]) - @_hub_auth_required - def api_hub_trend_sync_flat(): - """中控市价全平后:结束仍 active 的同币种同向趋势计划。""" - if not _ctx().get("has_trend"): - return jsonify({"ok": False, "msg": "该实例无趋势回调"}), 400 - body = request.get_json(silent=True) or {} - symbol = (body.get("symbol") or request.form.get("symbol") or "").strip() - side = ( - body.get("side") - or body.get("direction") - or request.form.get("side") - or "" - ).strip().lower() - if not symbol: - return jsonify({"ok": False, "msg": "symbol 不能为空"}), 400 - if side not in ("long", "short"): - return jsonify({"ok": False, "msg": "side 须为 long 或 short"}), 400 - cfg = current_app.extensions.get("strategy_trend_cfg") - get_db = _ctx().get("get_db") - if not cfg or not callable(get_db): - return jsonify({"ok": False, "msg": "趋势配置未就绪"}), 500 - from strategy_trend_register import sync_trend_plans_after_external_close - - conn = get_db() - try: - return jsonify(sync_trend_plans_after_external_close(cfg, conn, symbol, side)) - except Exception as e: - return jsonify({"ok": False, "msg": str(e)}), 500 - finally: - conn.close() - - @app.route("/api/hub/roll/sync-flat", methods=["POST"]) - @_hub_auth_required - def api_hub_roll_sync_flat(): - """中控/实例手动平仓后:取消滚仓 pending 并关闭 active 滚仓组。""" - body = request.get_json(silent=True) or {} - symbol = (body.get("symbol") or request.form.get("symbol") or "").strip() - side = ( - body.get("side") - or body.get("direction") - or request.form.get("side") - or "" - ).strip().lower() - if not symbol: - return jsonify({"ok": False, "msg": "symbol 不能为空"}), 400 - if side not in ("long", "short"): - return jsonify({"ok": False, "msg": "side 须为 long 或 short"}), 400 - cfg = current_app.extensions.get("strategy_roll_cfg") - get_db = _ctx().get("get_db") - if not cfg or not callable(get_db): - return jsonify({"ok": False, "msg": "滚仓配置未就绪"}), 500 - from strategy_register import roll_sync_after_external_close - - conn = get_db() - try: - out = roll_sync_after_external_close(cfg, conn, symbol, side) - conn.commit() - return jsonify(out) - except Exception as e: - return jsonify({"ok": False, "msg": str(e)}), 500 - finally: - conn.close() - - @app.route("/api/hub/trend/breakeven/", methods=["POST"]) - @_hub_auth_required - def api_hub_trend_breakeven(pid): - if not _ctx().get("has_trend"): - return jsonify({"ok": False, "msg": "该实例无趋势回调"}), 400 - body = request.get_json(silent=True) or {} - raw = (request.form.get("breakeven_offset_pct") or body.get("breakeven_offset_pct") or "").strip() - form = {} - if raw != "": - form["breakeven_offset_pct"] = raw - return jsonify( - _invoke_view( - "trend_pullback_breakeven", - f"/trend_pullback_breakeven/{pid}", - form=form, - ) - ) - - @app.route("/hub-sso") - def hub_sso_login(): - """中控签发的临时链接:写入 session 后跳转,直链访问仍走 /login。""" - from urllib.parse import urlencode - - auth_disabled = bool(current_app.config.get("HUB_AUTH_DISABLED")) - next_arg = request.args.get("next") - if auth_disabled: - session["logged_in"] = True - return redirect(safe_next_path(next_arg)) - ex = str((_ctx().get("exchange") or "")).strip().lower() - token = (request.args.get("token") or "").strip() - ok, next_path, err = verify_hub_sso_token(token, ex) - if ok: - embed_on = request.args.get("embed", "").strip().lower() in ( - "1", - "true", - "yes", - "on", - ) - dest_next = _embed_login_dest(next_path) if embed_on else next_path - if not embed_on: - ht = (request.args.get("hub_theme") or "").strip().lower() - if ht in ("light", "dark"): - dest_next = _merge_query_into_path(next_path, hub_theme=ht) - if embed_on and _sso_wants_embed_auth() and request.is_secure: - boot = mint_hub_embed_bootstrap(ex, dest_next) - if boot: - from urllib.parse import urlencode as _ue - - qdict = {"t": boot, "next": dest_next, "embed": "1"} - ht0 = (request.args.get("hub_theme") or "").strip().lower() - if ht0 in ("light", "dark"): - qdict["hub_theme"] = ht0 - return redirect(f"/hub-embed-auth?{_ue(qdict)}") - session["logged_in"] = True - session.modified = True - return redirect(dest_next) - hint = err or "校验失败" - flash( - f"中控 SSO 未生效({hint})。" - "请确认中控与实例 .env 中 HUB_BRIDGE_TOKEN 一致," - f"且中控设置里该账户 key 为「{ex}」。" - "经本地导航 iframe 打开时,实例须 HTTPS 且可设 APP_COOKIE_SECURE=true。" - ) - return redirect("/login") - - @app.route("/hub-embed-auth") - def hub_embed_auth_login(): - """LocalNav 等 iframe 内嵌:单独写入 SameSite=None 会话后跳转。""" - auth_disabled = bool(current_app.config.get("HUB_AUTH_DISABLED")) - next_arg = request.args.get("next") - if auth_disabled: - session["logged_in"] = True - return redirect(safe_next_path(next_arg)) - ex = str((_ctx().get("exchange") or "")).strip().lower() - boot = (request.args.get("t") or "").strip() - ok, next_path, err = verify_hub_embed_bootstrap(boot, ex) - if ok: - session["logged_in"] = True - session.modified = True - return redirect(_embed_login_dest(next_path)) - hint = err or "校验失败" - flash(f"iframe 登录未生效({hint})。可点本地导航工具栏「实例免密」重试。") - return redirect("/login") - - -def _latest_preview_id(): - get_db = _ctx().get("get_db") - if not get_db: - return None - conn = get_db() - row = conn.execute( - "SELECT id FROM trend_pullback_previews ORDER BY created_at DESC LIMIT 1" - ).fetchone() - conn.close() - return row["id"] if row else None - - -def _fetch_preview(pid): - get_db = _ctx().get("get_db") - if not get_db or not pid: - return None - conn = get_db() - row = conn.execute( - "SELECT * FROM trend_pullback_previews WHERE id=?", (pid,) - ).fetchone() - conn.close() - if not row: - return None - d = _row_to_dict(row) - now_ms = int(time.time() * 1000) - d["expires_in_sec"] = max(0, int((int(d.get("expires_at_ms") or 0) - now_ms) / 1000)) - try: - from strategy_trend_lib import build_trend_preview_level_rows - - enriched, level_rows = build_trend_preview_level_rows(d) - for key in ( - "preview_target_rr", - "preview_first_take_profit", - "preview_unified_stop_loss", - "preview_risk_amount_u", - "preview_first_profit_u", - "preview_take_profit_price", - ): - if key in enriched: - d[key] = enriched[key] - d["preview_level_rows"] = level_rows - d["grid_levels"] = [ - { - "i": row.get("i"), - "label": row.get("label"), - "price": row.get("price"), - "contracts": row.get("contracts"), - "cum_contracts": row.get("cum_contracts"), - "avg_entry": row.get("avg_entry"), - "take_profit_price": row.get("take_profit_price"), - "profit_u": row.get("profit_u"), - "risk_u": row.get("risk_u"), - "rr": row.get("rr"), - "stop_loss_price": row.get("stop_loss_price"), - "take_profit": row.get("profit_u"), - "stop_loss": row.get("risk_u"), - } - for row in level_rows - ] - except Exception: - d["grid_levels"] = [] - d["preview_level_rows"] = [] - return d +""" +各 crypto_monitor_* 注册 /api/hub/* JSON 接口,供 manual_trading_hub 调用。 +实例末尾:app.config["HUB_CTX"] = {...}; register_hub_routes(app) +""" + +from __future__ import annotations + +import json +import time +from functools import wraps + +from flask import ( + current_app, + flash, + get_flashed_messages, + jsonify, + redirect, + request, + session, +) + +from lib.hub.hub_auth import request_allowed +from lib.hub.hub_sso import ( + mint_hub_embed_bootstrap, + safe_next_path, + verify_hub_embed_bootstrap, + verify_hub_sso_token, +) + + +def _merge_query_into_path(path: str, **params: str) -> str: + from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit + + split = urlsplit(path or "/") + q = list(parse_qsl(split.query, keep_blank_values=True)) + keys = {k for k, _ in q} + for k, v in params.items(): + if not v or k in keys: + continue + q.append((k, str(v))) + return urlunsplit((split.scheme, split.netloc, split.path, urlencode(q), split.fragment)) + + +def install_instance_theme_static(app) -> None: + """仓库 lib/common/static 下 instance_theme.* 等供四所页面共用。""" + import os + + from flask import Response, send_file + + from lib.paths import common_static_dir + + repo_static = common_static_dir() + assets = { + "instance_theme.js": "application/javascript; charset=utf-8", + "instance_theme_early.css": "text/css; charset=utf-8", + "instance_theme.css": "text/css; charset=utf-8", + "account_risk_badge.css": "text/css; charset=utf-8", + "account_risk_badge.js": "application/javascript; charset=utf-8", + "instance_ui.js": "application/javascript; charset=utf-8", + "instance_records_mobile.js": "application/javascript; charset=utf-8", + "ai_review_render.js": "application/javascript; charset=utf-8", + "form_submit_guard.js": "application/javascript; charset=utf-8", + "key_monitor_form.js": "application/javascript; charset=utf-8", + "time_close_ui.js": "application/javascript; charset=utf-8", + "manual_order_rr_preview.js": "application/javascript; charset=utf-8", + "strategy_roll.js": "application/javascript; charset=utf-8", + "instance_page.css": "text/css; charset=utf-8", + "instance_embed.js": "application/javascript; charset=utf-8", + "focus_chart_page.js": "application/javascript; charset=utf-8", + "focus_chart_page.css": "text/css; charset=utf-8", + "trade_stats_calendar.js": "application/javascript; charset=utf-8", + "trade_stats_calendar.css": "text/css; charset=utf-8", + } + + for name, mime in assets.items(): + path = os.path.join(repo_static, name) + + def _view(p=path, m=mime): + if not os.path.isfile(p): + return Response("not found", status=404, mimetype="text/plain; charset=utf-8") + return send_file(p, mimetype=m) + + app.add_url_rule( + f"/static/{name}", + endpoint=f"repo_static_{name.replace('.', '_')}", + view_func=_view, + ) + + +def register_trade_stats_calendar_route( + app, + *, + login_required_fn, + load_pnls_fn, + row_matches_segment_fn, + reset_hour: int, + get_db_fn=None, +): + """四所统计分析页:按月返回各交易日盈亏/笔数。""" + from flask import jsonify, request + + from lib.trade.trade_stats_calendar_lib import build_trade_stats_calendar + + @app.route("/api/stats/calendar") + @login_required_fn + def api_stats_calendar(): + year = request.args.get("year", type=int) + month = request.args.get("month", type=int) + segment = (request.args.get("segment") or "all").strip() or "all" + if not year or not month: + from datetime import datetime + + now = datetime.now() + year = year or now.year + month = month or now.month + get_db = get_db_fn or (app.config.get("HUB_CTX") or {}).get("get_db") + if not get_db: + return jsonify({"ok": False, "msg": "未配置数据库"}), 500 + conn = get_db() + try: + pnls = load_pnls_fn(conn) + finally: + conn.close() + try: + payload = build_trade_stats_calendar( + pnls, + year, + month, + segment, + row_matches_segment_fn, + reset_hour=int(reset_hour), + ) + except ValueError as exc: + return jsonify({"ok": False, "msg": str(exc)}), 400 + return jsonify({"ok": True, **payload}) + + +def _hub_auth_required(f): + @wraps(f) + def wrapped(*args, **kwargs): + from flask import current_app as cap + + auth_disabled = bool(cap.config.get("HUB_AUTH_DISABLED")) + if not request_allowed(bool(session.get("logged_in")), auth_disabled): + return jsonify({"ok": False, "msg": "未授权(登录或 HUB_BRIDGE_TOKEN)"}), 401 + return f(*args, **kwargs) + + return wrapped + + +def _ctx(): + return current_app.config.get("HUB_CTX") or {} + + +def _row_to_dict(row): + fn = _ctx().get("row_to_dict") + if fn and row is not None: + return fn(row) + return dict(row) if row is not None else {} + + +def build_hub_monitor_payload( + *, + keys, + orders, + trends, + rolls, + enrich=None, + risk_status=None, +) -> dict: + """合并 enrich 增量字段;enrich 只返回 trends 等局部时不得丢掉 keys/orders。""" + payload = { + "ok": True, + "keys": keys, + "orders": orders, + "trends": trends, + "rolls": rolls, + "key_prices": [], + } + if isinstance(risk_status, dict): + payload["risk_status"] = risk_status + if callable(enrich): + extra = enrich(keys=keys, orders=orders, trends=trends, rolls=rolls) + if isinstance(extra, dict): + payload.update(extra) + return payload + + +_FAIL_HINTS = ( + "失败", + "错误", + "拒绝", + "无效", + "缺少", + "无法", + "过期", + "未达", + "不能为空", + "已有", + "不允许", + "异常", +) + + +def _invoke_view(view_name: str, path: str, form=None) -> dict: + views = _ctx().get("views") or {} + view = views.get(view_name) + if not view: + return {"ok": False, "messages": [f"未配置视图 {view_name}"]} + data = form if form is not None else request.form + if hasattr(data, "items") and not isinstance(data, dict): + data = {k: v for k, v in data.items()} + with current_app.test_request_context(path, method="POST", data=data): + session["logged_in"] = True + try: + view() + except Exception as e: + return {"ok": False, "messages": [str(e)]} + try: + msgs = [str(x) for x in get_flashed_messages()] + except Exception as e: + return {"ok": False, "messages": [f"读取提示信息失败: {e}"]} + ok = True + for m in msgs: + if any(k in m for k in _FAIL_HINTS): + ok = False + break + return {"ok": ok, "messages": msgs} + + +def _invoke_view_get(view_name: str, path: str) -> dict: + views = _ctx().get("views") or {} + view = views.get(view_name) + if not view: + return {"ok": False, "messages": [f"未配置视图 {view_name}"]} + with current_app.test_request_context(path, method="GET"): + session["logged_in"] = True + try: + view() + except Exception as e: + return {"ok": False, "messages": [str(e)]} + try: + msgs = [str(x) for x in get_flashed_messages()] + except Exception as e: + return {"ok": False, "messages": [f"读取提示信息失败: {e}"]} + ok = True + for m in msgs: + if any(k in m for k in _FAIL_HINTS): + ok = False + break + return {"ok": ok, "messages": msgs} + + +def _hub_json(view_name: str, path: str, form=None): + try: + return jsonify(_invoke_view(view_name, path, form=form)) + except Exception as e: + return jsonify({"ok": False, "messages": [str(e)]}) + + +def _embed_login_dest(next_path: str) -> str: + """embed=1 时把 /trade 等映射到 /embed?tab=…""" + ht = (request.args.get("hub_theme") or "").strip().lower() + hub_theme = ht if ht in ("light", "dark") else None + if request.args.get("embed", "").strip().lower() in ("1", "true", "yes", "on"): + from lib.instance.instance_embed_lib import rewrite_embed_dest + + return rewrite_embed_dest(next_path, hub_theme=hub_theme) + if hub_theme: + return _merge_query_into_path(next_path, hub_theme=hub_theme) + return next_path + + +def install_on_app( + app, + *, + exchange: str, + capabilities: list, + has_trend: bool, + get_db, + row_to_dict, + meta_fn, + views: dict, + ohlcv_fn=None, + account_fn=None, + volume_rank_fn=None, + market_fn=None, + reconcile_hub_flat_fn=None, + risk_status_fn=None, + user_close_fn=None, + render_main_page_fn=None, + login_required_fn=None, +): + app.config["HUB_CTX"] = { + "exchange": exchange, + "capabilities": list(capabilities), + "has_trend": bool(has_trend), + "get_db": get_db, + "row_to_dict": row_to_dict, + "meta_fn": meta_fn, + "account_fn": account_fn, + "views": views, + "ohlcv_fn": ohlcv_fn, + "volume_rank_fn": volume_rank_fn, + "market_fn": market_fn, + "reconcile_hub_flat_fn": reconcile_hub_flat_fn, + "risk_status_fn": risk_status_fn, + "user_close_fn": user_close_fn, + } + install_hub_embed_headers(app) + configure_hub_embed_session(app) + install_instance_theme_static(app) + register_hub_routes(app) + if render_main_page_fn and login_required_fn: + import os + + from lib.instance.instance_embed_lib import attach_embed_templates, register_embed_routes + + repo_root = os.path.dirname(os.path.abspath(__file__)) + attach_embed_templates(app, repo_root) + register_embed_routes(app, login_required_fn, render_main_page_fn) + + +def configure_hub_embed_session(app): + """HTTPS iframe 内嵌须 SameSite=None + Secure;hub-sso / hub-embed-auth 自动启用。""" + import os + + allowed = (os.getenv("APP_ALLOW_HUB_EMBED") or "true").strip().lower() in ( + "1", + "true", + "yes", + "on", + ) + if not allowed: + return + + secure_env = (os.getenv("APP_COOKIE_SECURE") or "auto").strip().lower() + if secure_env in ("1", "true", "yes", "on"): + app.config.update( + SESSION_COOKIE_SECURE=True, + SESSION_COOKIE_SAMESITE="None", + SESSION_COOKIE_HTTPONLY=True, + ) + return + + @app.before_request + def _hub_embed_session_cookie(): + if request.path not in ("/hub-sso", "/hub-embed-auth"): + return + embed = (request.args.get("embed") or "").strip().lower() in ( + "1", + "true", + "yes", + "on", + ) + in_iframe = (request.headers.get("Sec-Fetch-Dest") or "").lower() == "iframe" + if not embed and not in_iframe: + return + if not request.is_secure: + return + app.config["SESSION_COOKIE_SECURE"] = True + app.config["SESSION_COOKIE_SAMESITE"] = "None" + app.config["SESSION_COOKIE_HTTPONLY"] = True + + +def _sso_wants_embed_auth() -> bool: + embed = (request.args.get("embed") or "").strip().lower() in ( + "1", + "true", + "yes", + "on", + ) + in_iframe = (request.headers.get("Sec-Fetch-Dest") or "").lower() == "iframe" + return bool(embed or in_iframe) + + +def install_hub_embed_headers(app): + """允许复盘中控 iframe 内嵌打开本实例(须与 hub 的 HUB_EMBED_ORIGINS 或域名一致)。""" + import os + + allowed = (os.getenv("APP_ALLOW_HUB_EMBED") or "true").strip().lower() in ( + "1", + "true", + "yes", + "on", + ) + if not allowed: + return + origins = ( + (os.getenv("HUB_EMBED_PARENT_ORIGINS") or os.getenv("HUB_EMBED_ORIGINS") or "*") + .strip() + ) + + @app.after_request + def _hub_embed_frame_headers(response): + if origins == "*": + response.headers["Content-Security-Policy"] = "frame-ancestors *" + else: + response.headers["Content-Security-Policy"] = ( + f"frame-ancestors 'self' {origins}" + ) + return response + + +def register_hub_routes(app): + auth_disabled = False + try: + import os + + auth_disabled = os.getenv("APP_AUTH_DISABLED", "false").lower() in ( + "1", + "true", + "yes", + "on", + ) + except Exception: + pass + app.config.setdefault("HUB_AUTH_DISABLED", auth_disabled) + + @app.route("/api/hub/ping") + @_hub_auth_required + def api_hub_ping(): + c = _ctx() + return jsonify( + { + "ok": True, + "exchange": c.get("exchange"), + "capabilities": c.get("capabilities") or [], + } + ) + + @app.route("/api/hub/meta") + @_hub_auth_required + def api_hub_meta(): + c = _ctx() + meta_fn = c.get("meta_fn") + meta = meta_fn() if callable(meta_fn) else {} + return jsonify({"ok": True, "meta": meta}) + + @app.route("/api/hub/account") + @_hub_auth_required + def api_hub_account(): + """中控 AI:资金账户 / 交易账户余额(无需浏览器登录)。""" + fn = _ctx().get("account_fn") + if not callable(fn): + return jsonify({"ok": False, "msg": "未配置 account_fn"}), 501 + try: + data = fn() + if not isinstance(data, dict): + data = {} + return jsonify({"ok": True, **data}) + except Exception as e: + return jsonify({"ok": False, "msg": str(e)}), 500 + + @app.route("/api/account_risk_status") + @_hub_auth_required + def api_account_risk_status(): + c = _ctx() + get_db = c.get("get_db") + risk_fn = c.get("risk_status_fn") + if not callable(get_db) or not callable(risk_fn): + return jsonify({"ok": False, "msg": "未配置风控"}), 501 + conn = get_db() + try: + payload = risk_fn(conn) + return jsonify({"ok": True, **(payload if isinstance(payload, dict) else {})}) + except Exception as e: + return jsonify({"ok": False, "msg": str(e)}), 500 + finally: + conn.close() + + @app.route("/api/hub/account-risk/user-close", methods=["POST"]) + @_hub_auth_required + def api_hub_account_risk_user_close(): + """中控/实例:登记用户主动平仓(计入冷静期与日冻结)。""" + c = _ctx() + get_db = c.get("get_db") + user_close_fn = c.get("user_close_fn") + if not callable(get_db) or not callable(user_close_fn): + return jsonify({"ok": False, "msg": "未配置 user_close_fn"}), 501 + body = request.get_json(silent=True) or {} + source = (body.get("source") or request.form.get("source") or "").strip() + try: + count = max(0, int(body.get("count") if body.get("count") is not None else 1)) + except (TypeError, ValueError): + count = 1 + trade_record_id = body.get("trade_record_id") + closed_at_ms = body.get("closed_at_ms") + if count <= 0: + return jsonify({"ok": True, "skipped": True, "count": 0}) + conn = get_db() + try: + user_close_fn( + conn, + source=source, + count=count, + trade_record_id=trade_record_id, + closed_at_ms=closed_at_ms, + ) + conn.commit() + return jsonify({"ok": True, "count": count, "source": source}) + except Exception as e: + return jsonify({"ok": False, "msg": str(e)}), 500 + finally: + conn.close() + + @app.route("/api/hub/monitor") + @_hub_auth_required + def api_hub_monitor(): + c = _ctx() + get_db = c.get("get_db") + if not get_db: + return jsonify({"ok": False, "msg": "HUB_CTX 缺少 get_db"}), 500 + conn = get_db() + keys = [] + for row in conn.execute("SELECT * FROM key_monitors ORDER BY id DESC").fetchall(): + keys.append(_row_to_dict(row)) + orders = [] + for row in conn.execute( + "SELECT * FROM order_monitors WHERE status='active' ORDER BY id DESC" + ).fetchall(): + od = _row_to_dict(row) + try: + from lib.strategy.strategy_trade_labels import apply_order_monitor_source_labels + + od = apply_order_monitor_source_labels(od) + except Exception: + pass + orders.append(od) + trends = [] + if c.get("has_trend"): + for row in conn.execute( + "SELECT * FROM trend_pullback_plans WHERE status='active' ORDER BY id DESC" + ).fetchall(): + trends.append(_row_to_dict(row)) + rolls = [] + try: + for row in conn.execute( + """SELECT g.* FROM roll_groups g + INNER JOIN order_monitors m ON m.id = g.order_monitor_id AND m.status='active' + WHERE g.status='active' ORDER BY g.id DESC""" + ).fetchall(): + rolls.append(_row_to_dict(row)) + except Exception: + pass + risk_status = None + risk_fn = c.get("risk_status_fn") + if callable(risk_fn): + try: + risk_status = risk_fn(conn) + except Exception: + risk_status = None + conn.close() + enrich = c.get("enrich_monitor") + if callable(enrich): + try: + return jsonify( + build_hub_monitor_payload( + keys=keys, + orders=orders, + trends=trends, + rolls=rolls, + enrich=enrich, + risk_status=risk_status, + ) + ) + except Exception as e: + return jsonify({"ok": False, "msg": str(e)}), 500 + return jsonify( + build_hub_monitor_payload( + keys=keys, + orders=orders, + trends=trends, + rolls=rolls, + risk_status=risk_status, + ) + ) + + @app.route("/api/hub/trades/archive") + @_hub_auth_required + def api_hub_trades_archive(): + """中控币种档案:近 N 天已平仓记录。""" + from lib.hub.hub_trades_lib import fetch_trades_for_archive, summarize_trades + + c = _ctx() + get_db = c.get("get_db") + if not get_db: + return jsonify({"ok": False, "msg": "HUB_CTX 缺少 get_db"}), 500 + try: + days = int(request.args.get("days") or "365") + except ValueError: + days = 365 + try: + limit = int(request.args.get("limit") or "2000") + except ValueError: + limit = 2000 + try: + import os + + reset_hour = int(os.getenv("TRADING_DAY_RESET_HOUR", "8") or "8") + except ValueError: + reset_hour = 8 + conn = get_db() + try: + trades = fetch_trades_for_archive( + conn, + exchange_key=str(c.get("exchange") or ""), + days=days, + row_to_dict_fn=c.get("row_to_dict"), + reset_hour=reset_hour, + limit=limit, + ) + finally: + conn.close() + stats = summarize_trades(trades) + return jsonify( + { + "ok": True, + "days": max(1, min(days, 3650)), + "trading_day_reset_hour": reset_hour, + "trades": trades, + "stats": stats, + } + ) + + @app.route("/api/hub/trades/today") + @_hub_auth_required + def api_hub_trades_today(): + """中控 AI:当日已平仓记录(按实例交易日)。""" + from lib.hub.hub_trades_lib import ( + current_trading_day, + fetch_trades_for_trading_day, + summarize_trades, + ) + + c = _ctx() + get_db = c.get("get_db") + if not get_db: + return jsonify({"ok": False, "msg": "HUB_CTX 缺少 get_db"}), 500 + day_arg = (request.args.get("trading_day") or request.args.get("date") or "").strip()[:10] + try: + import os + + reset_hour = int(os.getenv("TRADING_DAY_RESET_HOUR", "8") or "8") + except ValueError: + reset_hour = 8 + trading_day = day_arg or current_trading_day(reset_hour=reset_hour) + conn = get_db() + try: + trades = fetch_trades_for_trading_day( + conn, + trading_day, + row_to_dict_fn=c.get("row_to_dict"), + reset_hour=reset_hour, + ) + finally: + conn.close() + stats = summarize_trades(trades) + return jsonify( + { + "ok": True, + "trading_day": trading_day, + "trading_day_reset_hour": reset_hour, + "trades": trades, + "stats": stats, + } + ) + + @app.route("/api/hub/volume-rank") + @_hub_auth_required + def api_hub_volume_rank(): + fn = _ctx().get("volume_rank_fn") + if not callable(fn): + return jsonify({"ok": False, "msg": "该实例未配置成交量排名接口"}), 501 + top_raw = (request.args.get("top") or "").strip() + top_n = 20 + if top_raw.isdigit(): + top_n = int(top_raw) + try: + result = fn(top_n=top_n) + if isinstance(result, dict): + return jsonify(result) + return jsonify({"ok": False, "msg": "成交量排名返回格式无效"}), 500 + except Exception as e: + return jsonify({"ok": False, "msg": str(e)}), 500 + + @app.route("/api/hub/market") + @_hub_auth_required + def api_hub_market(): + fn = _ctx().get("market_fn") + if not callable(fn): + return jsonify({"ok": False, "msg": "该实例未配置合约信息接口"}), 501 + base = (request.args.get("base") or request.args.get("symbol") or "").strip() + try: + result = fn(base=base) + if isinstance(result, dict): + return jsonify(result) + return jsonify({"ok": False, "msg": "合约信息返回格式无效"}), 500 + except Exception as e: + return jsonify({"ok": False, "msg": str(e)}), 500 + + @app.route("/api/hub/ohlcv") + @_hub_auth_required + def api_hub_ohlcv(): + fn = _ctx().get("ohlcv_fn") + if not callable(fn): + return jsonify({"ok": False, "msg": "该实例未配置 OHLCV 接口"}), 501 + symbol = (request.args.get("symbol") or "").strip() + timeframe = (request.args.get("timeframe") or "5m").strip() + since_raw = (request.args.get("since_ms") or "").strip() + limit_raw = (request.args.get("limit") or "").strip() + since_ms = None + if since_raw.isdigit(): + since_ms = int(since_raw) + limit = 500 + if limit_raw.isdigit(): + limit = int(limit_raw) + try: + result = fn(symbol=symbol, timeframe=timeframe, since_ms=since_ms, limit=limit) + if isinstance(result, dict): + return jsonify(result) + return jsonify({"ok": False, "msg": "OHLCV 返回格式无效"}), 500 + except Exception as e: + return jsonify({"ok": False, "msg": str(e)}), 500 + + @app.route("/api/hub/add_order", methods=["POST"]) + @_hub_auth_required + def api_hub_add_order(): + return _hub_json("add_order", "/add_order") + + @app.route("/api/hub/add_key", methods=["POST"]) + @_hub_auth_required + def api_hub_add_key(): + return _hub_json("add_key", "/add_key") + + @app.route("/api/hub/trend/preview", methods=["POST"]) + @_hub_auth_required + def api_hub_trend_preview(): + if not _ctx().get("has_trend"): + return jsonify({"ok": False, "msg": "该实例无趋势回调"}), 400 + data = _invoke_view("preview_trend_pullback", "/trade") + pid = _latest_preview_id() + preview = _fetch_preview(pid) if pid else None + return jsonify( + { + "ok": bool(data.get("ok")), + "messages": data.get("messages") or [], + "preview_id": pid, + "preview": preview, + } + ) + + @app.route("/api/hub/trend/execute", methods=["POST"]) + @_hub_auth_required + def api_hub_trend_execute(): + if not _ctx().get("has_trend"): + return jsonify({"ok": False, "msg": "该实例无趋势回调"}), 400 + pid = (request.form.get("preview_id") or "").strip() + if not pid: + body = request.get_json(silent=True) or {} + pid = str(body.get("preview_id") or "").strip() + form = {"preview_id": pid} if pid else {} + return jsonify(_invoke_view("execute_trend_pullback", "/trade", form=form)) + + @app.route("/api/hub/trend/preview/") + @_hub_auth_required + def api_hub_trend_preview_get(pid): + if not _ctx().get("has_trend"): + return jsonify({"ok": False, "msg": "该实例无趋势回调"}), 400 + preview = _fetch_preview(pid) + if not preview: + return jsonify({"ok": False, "msg": "预览不存在或已过期"}), 404 + return jsonify({"ok": True, "preview": preview}) + + @app.route("/api/hub/trend/stop/", methods=["POST"]) + @_hub_auth_required + def api_hub_trend_stop(pid): + if not _ctx().get("has_trend"): + return jsonify({"ok": False, "msg": "该实例无趋势回调"}), 400 + return jsonify(_invoke_view_get("stop_trend_pullback", f"/stop_trend_pullback/{pid}")) + + @app.route("/api/hub/order/sync-flat", methods=["POST"]) + @_hub_auth_required + def api_hub_order_sync_flat(): + """中控市价全平后:同步 order_monitors 并读 Gate 平仓历史写交易记录。""" + fn = _ctx().get("reconcile_hub_flat_fn") + if not callable(fn): + return jsonify({"ok": False, "msg": "该实例未配置 order sync-flat"}), 400 + body = request.get_json(silent=True) or {} + symbol = (body.get("symbol") or request.form.get("symbol") or "").strip() + side = ( + body.get("side") + or body.get("direction") + or request.form.get("side") + or "" + ).strip().lower() + if not symbol: + return jsonify({"ok": False, "msg": "symbol 不能为空"}), 400 + if side not in ("long", "short"): + return jsonify({"ok": False, "msg": "side 须为 long 或 short"}), 400 + get_db = _ctx().get("get_db") + if not callable(get_db): + return jsonify({"ok": False, "msg": "HUB_CTX 缺少 get_db"}), 500 + conn = get_db() + try: + out = fn(conn, symbol, side) + if not isinstance(out, dict): + out = {"ok": True, "synced": int(out or 0)} + conn.commit() + return jsonify(out) + except Exception as e: + return jsonify({"ok": False, "msg": str(e)}), 500 + finally: + conn.close() + + @app.route("/api/hub/trend/sync-flat", methods=["POST"]) + @_hub_auth_required + def api_hub_trend_sync_flat(): + """中控市价全平后:结束仍 active 的同币种同向趋势计划。""" + if not _ctx().get("has_trend"): + return jsonify({"ok": False, "msg": "该实例无趋势回调"}), 400 + body = request.get_json(silent=True) or {} + symbol = (body.get("symbol") or request.form.get("symbol") or "").strip() + side = ( + body.get("side") + or body.get("direction") + or request.form.get("side") + or "" + ).strip().lower() + if not symbol: + return jsonify({"ok": False, "msg": "symbol 不能为空"}), 400 + if side not in ("long", "short"): + return jsonify({"ok": False, "msg": "side 须为 long 或 short"}), 400 + cfg = current_app.extensions.get("strategy_trend_cfg") + get_db = _ctx().get("get_db") + if not cfg or not callable(get_db): + return jsonify({"ok": False, "msg": "趋势配置未就绪"}), 500 + from lib.strategy.strategy_trend_register import sync_trend_plans_after_external_close + + conn = get_db() + try: + return jsonify(sync_trend_plans_after_external_close(cfg, conn, symbol, side)) + except Exception as e: + return jsonify({"ok": False, "msg": str(e)}), 500 + finally: + conn.close() + + @app.route("/api/hub/roll/sync-flat", methods=["POST"]) + @_hub_auth_required + def api_hub_roll_sync_flat(): + """中控/实例手动平仓后:取消滚仓 pending 并关闭 active 滚仓组。""" + body = request.get_json(silent=True) or {} + symbol = (body.get("symbol") or request.form.get("symbol") or "").strip() + side = ( + body.get("side") + or body.get("direction") + or request.form.get("side") + or "" + ).strip().lower() + if not symbol: + return jsonify({"ok": False, "msg": "symbol 不能为空"}), 400 + if side not in ("long", "short"): + return jsonify({"ok": False, "msg": "side 须为 long 或 short"}), 400 + cfg = current_app.extensions.get("strategy_roll_cfg") + get_db = _ctx().get("get_db") + if not cfg or not callable(get_db): + return jsonify({"ok": False, "msg": "滚仓配置未就绪"}), 500 + from lib.strategy.strategy_register import roll_sync_after_external_close + + conn = get_db() + try: + out = roll_sync_after_external_close(cfg, conn, symbol, side) + conn.commit() + return jsonify(out) + except Exception as e: + return jsonify({"ok": False, "msg": str(e)}), 500 + finally: + conn.close() + + @app.route("/api/hub/trend/breakeven/", methods=["POST"]) + @_hub_auth_required + def api_hub_trend_breakeven(pid): + if not _ctx().get("has_trend"): + return jsonify({"ok": False, "msg": "该实例无趋势回调"}), 400 + body = request.get_json(silent=True) or {} + raw = (request.form.get("breakeven_offset_pct") or body.get("breakeven_offset_pct") or "").strip() + form = {} + if raw != "": + form["breakeven_offset_pct"] = raw + return jsonify( + _invoke_view( + "trend_pullback_breakeven", + f"/trend_pullback_breakeven/{pid}", + form=form, + ) + ) + + @app.route("/hub-sso") + def hub_sso_login(): + """中控签发的临时链接:写入 session 后跳转,直链访问仍走 /login。""" + from urllib.parse import urlencode + + auth_disabled = bool(current_app.config.get("HUB_AUTH_DISABLED")) + next_arg = request.args.get("next") + if auth_disabled: + session["logged_in"] = True + return redirect(safe_next_path(next_arg)) + ex = str((_ctx().get("exchange") or "")).strip().lower() + token = (request.args.get("token") or "").strip() + ok, next_path, err = verify_hub_sso_token(token, ex) + if ok: + embed_on = request.args.get("embed", "").strip().lower() in ( + "1", + "true", + "yes", + "on", + ) + dest_next = _embed_login_dest(next_path) if embed_on else next_path + if not embed_on: + ht = (request.args.get("hub_theme") or "").strip().lower() + if ht in ("light", "dark"): + dest_next = _merge_query_into_path(next_path, hub_theme=ht) + if embed_on and _sso_wants_embed_auth() and request.is_secure: + boot = mint_hub_embed_bootstrap(ex, dest_next) + if boot: + from urllib.parse import urlencode as _ue + + qdict = {"t": boot, "next": dest_next, "embed": "1"} + ht0 = (request.args.get("hub_theme") or "").strip().lower() + if ht0 in ("light", "dark"): + qdict["hub_theme"] = ht0 + return redirect(f"/hub-embed-auth?{_ue(qdict)}") + session["logged_in"] = True + session.modified = True + return redirect(dest_next) + hint = err or "校验失败" + flash( + f"中控 SSO 未生效({hint})。" + "请确认中控与实例 .env 中 HUB_BRIDGE_TOKEN 一致," + f"且中控设置里该账户 key 为「{ex}」。" + "经本地导航 iframe 打开时,实例须 HTTPS 且可设 APP_COOKIE_SECURE=true。" + ) + return redirect("/login") + + @app.route("/hub-embed-auth") + def hub_embed_auth_login(): + """LocalNav 等 iframe 内嵌:单独写入 SameSite=None 会话后跳转。""" + auth_disabled = bool(current_app.config.get("HUB_AUTH_DISABLED")) + next_arg = request.args.get("next") + if auth_disabled: + session["logged_in"] = True + return redirect(safe_next_path(next_arg)) + ex = str((_ctx().get("exchange") or "")).strip().lower() + boot = (request.args.get("t") or "").strip() + ok, next_path, err = verify_hub_embed_bootstrap(boot, ex) + if ok: + session["logged_in"] = True + session.modified = True + return redirect(_embed_login_dest(next_path)) + hint = err or "校验失败" + flash(f"iframe 登录未生效({hint})。可点本地导航工具栏「实例免密」重试。") + return redirect("/login") + + +def _latest_preview_id(): + get_db = _ctx().get("get_db") + if not get_db: + return None + conn = get_db() + row = conn.execute( + "SELECT id FROM trend_pullback_previews ORDER BY created_at DESC LIMIT 1" + ).fetchone() + conn.close() + return row["id"] if row else None + + +def _fetch_preview(pid): + get_db = _ctx().get("get_db") + if not get_db or not pid: + return None + conn = get_db() + row = conn.execute( + "SELECT * FROM trend_pullback_previews WHERE id=?", (pid,) + ).fetchone() + conn.close() + if not row: + return None + d = _row_to_dict(row) + now_ms = int(time.time() * 1000) + d["expires_in_sec"] = max(0, int((int(d.get("expires_at_ms") or 0) - now_ms) / 1000)) + try: + from lib.strategy.strategy_trend_lib import build_trend_preview_level_rows + + enriched, level_rows = build_trend_preview_level_rows(d) + for key in ( + "preview_target_rr", + "preview_first_take_profit", + "preview_unified_stop_loss", + "preview_risk_amount_u", + "preview_first_profit_u", + "preview_take_profit_price", + ): + if key in enriched: + d[key] = enriched[key] + d["preview_level_rows"] = level_rows + d["grid_levels"] = [ + { + "i": row.get("i"), + "label": row.get("label"), + "price": row.get("price"), + "contracts": row.get("contracts"), + "cum_contracts": row.get("cum_contracts"), + "avg_entry": row.get("avg_entry"), + "take_profit_price": row.get("take_profit_price"), + "profit_u": row.get("profit_u"), + "risk_u": row.get("risk_u"), + "rr": row.get("rr"), + "stop_loss_price": row.get("stop_loss_price"), + "take_profit": row.get("profit_u"), + "stop_loss": row.get("risk_u"), + } + for row in level_rows + ] + except Exception: + d["grid_levels"] = [] + d["preview_level_rows"] = [] + return d diff --git a/hub_calculator_lib.py b/lib/hub/hub_calculator_lib.py similarity index 96% rename from hub_calculator_lib.py rename to lib/hub/hub_calculator_lib.py index b4eed26..ea0ea39 100644 --- a/hub_calculator_lib.py +++ b/lib/hub/hub_calculator_lib.py @@ -1,498 +1,498 @@ -"""中控历史测算:趋势回调 / 滚仓,以损定仓(按交易所精度与张数规则)。""" -from __future__ import annotations - -from typing import Any, Callable, Optional, Tuple - -from strategy_roll_lib import max_roll_legs -from strategy_trend_lib import ( - build_trend_preview_level_rows, - calc_risk_fraction, - compute_trend_plan_core, - validate_trend_bounds, -) - -DEFAULT_DCA_LEGS = 5 -MARGIN_BUFFER = 0.95 - - -def _resolve_market( - exchange_id: str, - base: str, -) -> Tuple[Optional[dict[str, Any]], Optional[Callable[[float], Optional[float]]], Optional[str]]: - from hub_calculator_market_lib import get_calculator_market, make_amount_precise_fn_from_market - - market, err = get_calculator_market(exchange_id, base) - if err or not market: - return None, None, err or "无法解析合约" - amount_precise = make_amount_precise_fn_from_market(market) - return market, amount_precise, None - - -def calc_trend_calculator( - *, - direction: str, - capital_usdt: float, - risk_percent: float, - leverage: int, - entry_price: float, - stop_loss: float, - add_upper: float, - take_profit: float, - dca_legs: int = DEFAULT_DCA_LEGS, - exchange_id: str = "0", - base: str = "ETH", -) -> Tuple[Optional[dict[str, Any]], Optional[str]]: - market, amount_precise, merr = _resolve_market(exchange_id, base) - if merr or not market or not amount_precise: - return None, merr or "无法解析合约" - contract_size = float(market.get("contract_size") or 1.0) - exchange_symbol = market["exchange_symbol"] - - direction = (direction or "long").strip().lower() - if direction not in ("long", "short"): - return None, "方向须为 long 或 short" - try: - capital = float(capital_usdt) - rp = float(risk_percent) - lev = int(leverage) - entry = float(entry_price) - sl = float(stop_loss) - upper = float(add_upper) - tp = float(take_profit) - legs = max(1, int(dca_legs)) - cs = float(contract_size) if contract_size else 1.0 - except (TypeError, ValueError): - return None, "参数格式错误" - if capital <= 0 or rp <= 0 or lev <= 0 or entry <= 0 or sl <= 0 or upper <= 0 or tp <= 0: - return None, "资金、风险、杠杆与价格须大于 0" - - bound_err = validate_trend_bounds(direction, sl, upper) - if bound_err: - return None, bound_err - - rf = calc_risk_fraction(direction, upper, sl) - if rf is None or rf <= 0: - return None, "止损与补仓区间边界组合无法计算风险比例" - - risk_budget = capital * (rp / 100.0) - notional = risk_budget / rf - margin_plan = min(notional / float(lev), capital * MARGIN_BUFFER) - if margin_plan <= 0: - return None, "计划保证金过小" - - target_amt = _amount_from_margin(margin_plan, lev, entry, cs) - if target_amt is None or target_amt <= 0: - return None, "无法计算计划张数,请检查入场价与杠杆" - target_amt = amount_precise(target_amt) - if target_amt is None or target_amt <= 0: - return None, "计划张数低于交易所最小精度" - - def _amount_precise(_symbol: str, amount: float) -> Optional[float]: - return amount_precise(amount) - - payload, err = compute_trend_plan_core( - direction=direction, - stop_loss=sl, - add_upper=upper, - risk_percent=rp, - snapshot_usdt=capital, - leverage=lev, - live_price=entry, - target_order_amount=target_amt, - exchange_symbol=exchange_symbol, - dca_legs=legs, - amount_precise=_amount_precise, - min_amount=float(market.get("min_amount") or 0.0), - full_margin_buffer_ratio=MARGIN_BUFFER, - ) - if err: - return None, err - - payload["take_profit"] = tp - payload["leverage"] = lev - payload["contract_size"] = cs - preview, rows = build_trend_preview_level_rows(payload) - - px_dec = int(market.get("price_decimals") or 4) - amt_dec = int(market.get("amount_decimals") or 4) - - def _f(v: Any, nd: int | None = None) -> Any: - if v is None: - return None - try: - return round(float(v), nd if nd is not None else 8) - except (TypeError, ValueError): - return v - - table = [] - for row in rows: - table.append( - { - "label": row.get("label"), - "price": _f(row.get("price"), px_dec), - "contracts": _f(row.get("contracts"), amt_dec), - "avg_entry": _f(row.get("avg_entry"), px_dec), - "profit_u": _f(row.get("profit_u")), - "risk_u": _f(row.get("risk_u")), - "rr": _f(row.get("rr"), 4), - } - ) - - return { - "direction": direction, - "capital_usdt": _f(capital), - "risk_percent": _f(rp, 2), - "risk_budget_u": _f(preview.get("preview_risk_amount_u")), - "leverage": lev, - "entry_price": _f(entry, px_dec), - "stop_loss": _f(sl, px_dec), - "add_upper": _f(upper, px_dec), - "take_profit": _f(tp, px_dec), - "plan_margin_u": _f(preview.get("plan_margin_capital")), - "target_contracts": _f(preview.get("target_order_amount"), amt_dec), - "first_contracts": _f(preview.get("first_order_amount"), amt_dec), - "dca_legs": int(preview.get("dca_legs") or legs), - "first_profit_u": _f(preview.get("preview_first_profit_u")), - "first_rr": _f(preview.get("preview_target_rr"), 4), - "market": market, - "rows": table, - }, None - - -def _amount_from_margin( - margin_capital: float, - leverage: int, - price: float, - contract_size: float, -) -> Optional[float]: - try: - margin = float(margin_capital) - lev = int(leverage) - px = float(price) - cs = float(contract_size) if contract_size else 1.0 - except (TypeError, ValueError): - return None - if margin <= 0 or lev <= 0 or px <= 0 or cs <= 0: - return None - notional = margin * lev - return notional / (px * cs) - - -def _round(v: Any, nd: int = 4) -> Any: - if v is None: - return None - try: - return round(float(v), nd) - except (TypeError, ValueError): - return v - - -def _money_rr(profit_u: Optional[float], risk_u: Optional[float]) -> Optional[float]: - try: - if risk_u is None or float(risk_u) <= 0 or profit_u is None: - return None - return round(float(profit_u) / float(risk_u), 4) - except (TypeError, ValueError): - return None - - -def calc_initial_roll_qty( - direction: str, - entry_price: float, - stop_loss: float, - risk_budget_usdt: float, - contract_size: float = 1.0, -) -> Tuple[Optional[float], Optional[str]]: - """首仓以损定仓:打到初始止损亏损 = 风险预算。""" - try: - entry = float(entry_price) - sl = float(stop_loss) - budget = float(risk_budget_usdt) - cs = float(contract_size) if contract_size else 1.0 - except (TypeError, ValueError): - return None, "参数格式错误" - if entry <= 0 or sl <= 0 or budget <= 0 or cs <= 0: - return None, "入场价、止损与风险预算须大于 0" - direction = (direction or "long").strip().lower() - if direction == "short": - per_unit = (sl - entry) * cs - if per_unit <= 0: - return None, "做空:止损价须高于首仓入场价" - else: - per_unit = (entry - sl) * cs - if per_unit <= 0: - return None, "做多:止损价须低于首仓入场价" - return budget / per_unit, None - - -def solve_add_amount_for_total_risk( - direction: str, - qty_existing: float, - entry_existing: float, - add_price: float, - new_stop: float, - risk_budget_usdt: float, - contract_size: float = 1.0, -) -> Tuple[Optional[float], Optional[str]]: - """合并持仓打到新止损总亏损 = 风险预算,反推本次加仓张数。""" - try: - q1 = float(qty_existing) - e1 = float(entry_existing) - e2 = float(add_price) - sl = float(new_stop) - b = float(risk_budget_usdt) - cs = float(contract_size) if contract_size else 1.0 - except (TypeError, ValueError): - return None, "参数格式错误" - if q1 <= 0 or e1 <= 0 or e2 <= 0 or b <= 0 or cs <= 0: - return None, "持仓或风险预算无效" - direction = (direction or "long").strip().lower() - if direction == "short": - denom = sl - e2 - numer = b / cs - q1 * (sl - e1) - if denom <= 0: - return None, "做空:新止损须高于限价加仓价" - else: - denom = e2 - sl - numer = b / cs - q1 * (e1 - sl) - if denom <= 0: - return None, "做多:新止损须低于限价/市价加仓价" - q2 = numer / denom - if q2 <= 0: - return None, "按当前新止损与总风险%,无需加仓或无法再加(已满足风险上限)" - return q2, None - - -def _roll_leg_preview( - *, - direction: str, - qty_existing: float, - entry_existing: float, - take_profit: float, - add_price: float, - new_stop_loss: float, - risk_budget: float, - contract_size: float, - amount_precise: Callable[[float], Optional[float]], -) -> Tuple[Optional[dict[str, Any]], Optional[str]]: - direction = (direction or "long").strip().lower() - try: - tp = float(take_profit) - sl = float(new_stop_loss) - entry_add = float(add_price) - e1 = float(entry_existing) - except (TypeError, ValueError): - return None, "止损/止盈格式错误" - if sl <= 0 or tp <= 0 or entry_add <= 0: - return None, "止损与首仓止盈须大于0" - if direction == "long": - if sl >= entry_add: - return None, "做多:新止损须低于加仓价" - if tp <= e1: - return None, "做多:首仓止盈须高于当前持仓均价参考" - else: - if sl <= entry_add: - return None, "做空:新止损须高于加仓价" - if tp >= e1: - return None, "做空:首仓止盈须低于当前持仓均价参考" - - q2_raw, err = solve_add_amount_for_total_risk( - direction, - qty_existing, - entry_existing, - entry_add, - sl, - risk_budget, - contract_size, - ) - if err: - return None, err - q2 = amount_precise(float(q2_raw)) - if q2 is None or q2 <= 0: - return None, "加仓张数低于交易所最小精度" - new_qty = float(qty_existing) + float(q2) - new_avg = (float(qty_existing) * float(entry_existing) + float(q2) * entry_add) / new_qty - cs = float(contract_size) if contract_size else 1.0 - if direction == "long": - loss_at_sl = (new_avg - sl) * new_qty * cs - reward_at_tp = (tp - new_avg) * new_qty * cs - else: - loss_at_sl = (sl - new_avg) * new_qty * cs - reward_at_tp = (new_avg - tp) * new_qty * cs - return { - "add_amount_raw": q2, - "qty_after": new_qty, - "avg_entry_after": new_avg, - "add_price": entry_add, - "new_stop_loss": sl, - "loss_at_sl_usdt": loss_at_sl, - "reward_at_tp_usdt": reward_at_tp, - }, None - - -def calc_roll_calculator( - *, - direction: str, - capital_usdt: float, - risk_percent: float, - entry_price: float, - stop_loss: float, - take_profit: float, - add_legs: list[dict[str, float]] | None = None, - legs_done: int = 0, - exchange_id: str = "0", - base: str = "ETH", -) -> Tuple[Optional[dict[str, Any]], Optional[str]]: - """ - 滚仓历史测算:首仓自动以损定仓;止盈锁定首仓价;最多 3 次滚仓加仓。 - add_legs: [{add_price, new_stop_loss}, ...],按顺序链式计算。 - legs_done: 已完成滚仓次数(仅标记,仍参与链式状态推进)。 - """ - market, amount_precise, merr = _resolve_market(exchange_id, base) - if merr or not market or not amount_precise: - return None, merr or "无法解析合约" - contract_size = float(market.get("contract_size") or 1.0) - px_dec = int(market.get("price_decimals") or 4) - amt_dec = int(market.get("amount_decimals") or 4) - - direction = (direction or "long").strip().lower() - if direction not in ("long", "short"): - return None, "方向须为 long 或 short" - try: - capital = float(capital_usdt) - rp = float(risk_percent) - entry = float(entry_price) - initial_sl = float(stop_loss) - tp = float(take_profit) - done = max(0, int(legs_done)) - except (TypeError, ValueError): - return None, "参数格式错误" - if capital <= 0 or rp <= 0 or entry <= 0 or initial_sl <= 0 or tp <= 0: - return None, "资金、风险与价格须大于 0" - if done > max_roll_legs(direction): - return None, f"已完成滚仓次数不能超过 {max_roll_legs(direction)} 次" - - legs_in: list[dict[str, float]] = [] - for raw in add_legs or []: - if not isinstance(raw, dict): - continue - try: - ap = float(raw.get("add_price")) - nsl = float(raw.get("new_stop_loss")) - except (TypeError, ValueError): - return None, "加仓价与新止损须为有效数字" - if ap <= 0 or nsl <= 0: - return None, "加仓价与新止损须大于 0" - legs_in.append({"add_price": ap, "new_stop_loss": nsl}) - - if done + len(legs_in) > max_roll_legs(direction): - return None, f"已完成 {done} 次 + 待测算 {len(legs_in)} 次,合计不能超过 {max_roll_legs(direction)} 次滚仓" - - if direction == "long": - if tp <= entry: - return None, "做多:止盈价须高于首仓入场价" - else: - if tp >= entry: - return None, "做空:止盈价须低于首仓入场价" - - risk_budget = capital * (rp / 100.0) - qty, err = calc_initial_roll_qty(direction, entry, initial_sl, risk_budget, contract_size) - if err: - return None, err - if qty is None or qty <= 0: - return None, "无法计算首仓张数" - qty_p = amount_precise(float(qty)) - if qty_p is None or qty_p <= 0: - return None, "首仓张数低于交易所最小精度" - - qty_f = float(qty_p) - avg = entry - rows: list[dict[str, Any]] = [] - cs = contract_size - - if direction == "long": - first_loss = (avg - initial_sl) * qty_f * cs - first_profit = (tp - avg) * qty_f * cs - else: - first_loss = (initial_sl - avg) * qty_f * cs - first_profit = (avg - tp) * qty_f * cs - - rows.append( - { - "label": "首仓", - "leg_index": 0, - "already_done": False, - "entry_or_add_price": _round(entry, px_dec), - "stop_loss": _round(initial_sl, px_dec), - "add_contracts": _round(qty_f, amt_dec), - "total_contracts": _round(qty_f, amt_dec), - "avg_entry": _round(avg, px_dec), - "take_profit": _round(tp, px_dec), - "loss_at_sl_u": _round(first_loss), - "profit_at_tp_u": _round(first_profit), - "rr": _money_rr(first_profit, first_loss), - } - ) - - current_qty = qty_f - current_avg = avg - - for i, leg in enumerate(legs_in): - leg_no = i + 1 - preview, err = _roll_leg_preview( - direction=direction, - qty_existing=current_qty, - entry_existing=current_avg, - take_profit=tp, - add_price=leg["add_price"], - new_stop_loss=leg["new_stop_loss"], - risk_budget=risk_budget, - contract_size=cs, - amount_precise=amount_precise, - ) - if err: - return None, f"滚仓第 {leg_no} 次:{err}" - if not preview: - return None, f"滚仓第 {leg_no} 次计算失败" - - current_qty = float(preview["qty_after"]) - current_avg = float(preview["avg_entry_after"]) - loss = preview.get("loss_at_sl_usdt") - reward = preview.get("reward_at_tp_usdt") - rows.append( - { - "label": f"滚仓{leg_no}", - "leg_index": leg_no, - "already_done": leg_no <= done, - "entry_or_add_price": _round(preview.get("add_price"), px_dec), - "stop_loss": _round(preview.get("new_stop_loss"), px_dec), - "add_contracts": _round(preview.get("add_amount_raw"), amt_dec), - "total_contracts": _round(current_qty, amt_dec), - "avg_entry": _round(current_avg, px_dec), - "take_profit": _round(tp, px_dec), - "loss_at_sl_u": _round(loss), - "profit_at_tp_u": _round(reward), - "rr": _money_rr(reward, loss), - } - ) - - last = rows[-1] - return { - "direction": direction, - "capital_usdt": _round(capital), - "risk_percent": _round(rp, 2), - "risk_budget_u": _round(risk_budget), - "entry_price": _round(entry, px_dec), - "stop_loss": _round(initial_sl, px_dec), - "take_profit": _round(tp, px_dec), - "legs_done": done, - "roll_legs_planned": len(legs_in), - "first_contracts": _round(qty_f, amt_dec), - "final_contracts": last.get("total_contracts"), - "final_avg_entry": last.get("avg_entry"), - "final_loss_at_sl_u": last.get("loss_at_sl_u"), - "final_profit_at_tp_u": last.get("profit_at_tp_u"), - "final_rr": last.get("rr"), - "market": market, - "rows": rows, - }, None +"""中控历史测算:趋势回调 / 滚仓,以损定仓(按交易所精度与张数规则)。""" +from __future__ import annotations + +from typing import Any, Callable, Optional, Tuple + +from lib.strategy.strategy_roll_lib import max_roll_legs +from lib.strategy.strategy_trend_lib import ( + build_trend_preview_level_rows, + calc_risk_fraction, + compute_trend_plan_core, + validate_trend_bounds, +) + +DEFAULT_DCA_LEGS = 5 +MARGIN_BUFFER = 0.95 + + +def _resolve_market( + exchange_id: str, + base: str, +) -> Tuple[Optional[dict[str, Any]], Optional[Callable[[float], Optional[float]]], Optional[str]]: + from lib.hub.hub_calculator_market_lib import get_calculator_market, make_amount_precise_fn_from_market + + market, err = get_calculator_market(exchange_id, base) + if err or not market: + return None, None, err or "无法解析合约" + amount_precise = make_amount_precise_fn_from_market(market) + return market, amount_precise, None + + +def calc_trend_calculator( + *, + direction: str, + capital_usdt: float, + risk_percent: float, + leverage: int, + entry_price: float, + stop_loss: float, + add_upper: float, + take_profit: float, + dca_legs: int = DEFAULT_DCA_LEGS, + exchange_id: str = "0", + base: str = "ETH", +) -> Tuple[Optional[dict[str, Any]], Optional[str]]: + market, amount_precise, merr = _resolve_market(exchange_id, base) + if merr or not market or not amount_precise: + return None, merr or "无法解析合约" + contract_size = float(market.get("contract_size") or 1.0) + exchange_symbol = market["exchange_symbol"] + + direction = (direction or "long").strip().lower() + if direction not in ("long", "short"): + return None, "方向须为 long 或 short" + try: + capital = float(capital_usdt) + rp = float(risk_percent) + lev = int(leverage) + entry = float(entry_price) + sl = float(stop_loss) + upper = float(add_upper) + tp = float(take_profit) + legs = max(1, int(dca_legs)) + cs = float(contract_size) if contract_size else 1.0 + except (TypeError, ValueError): + return None, "参数格式错误" + if capital <= 0 or rp <= 0 or lev <= 0 or entry <= 0 or sl <= 0 or upper <= 0 or tp <= 0: + return None, "资金、风险、杠杆与价格须大于 0" + + bound_err = validate_trend_bounds(direction, sl, upper) + if bound_err: + return None, bound_err + + rf = calc_risk_fraction(direction, upper, sl) + if rf is None or rf <= 0: + return None, "止损与补仓区间边界组合无法计算风险比例" + + risk_budget = capital * (rp / 100.0) + notional = risk_budget / rf + margin_plan = min(notional / float(lev), capital * MARGIN_BUFFER) + if margin_plan <= 0: + return None, "计划保证金过小" + + target_amt = _amount_from_margin(margin_plan, lev, entry, cs) + if target_amt is None or target_amt <= 0: + return None, "无法计算计划张数,请检查入场价与杠杆" + target_amt = amount_precise(target_amt) + if target_amt is None or target_amt <= 0: + return None, "计划张数低于交易所最小精度" + + def _amount_precise(_symbol: str, amount: float) -> Optional[float]: + return amount_precise(amount) + + payload, err = compute_trend_plan_core( + direction=direction, + stop_loss=sl, + add_upper=upper, + risk_percent=rp, + snapshot_usdt=capital, + leverage=lev, + live_price=entry, + target_order_amount=target_amt, + exchange_symbol=exchange_symbol, + dca_legs=legs, + amount_precise=_amount_precise, + min_amount=float(market.get("min_amount") or 0.0), + full_margin_buffer_ratio=MARGIN_BUFFER, + ) + if err: + return None, err + + payload["take_profit"] = tp + payload["leverage"] = lev + payload["contract_size"] = cs + preview, rows = build_trend_preview_level_rows(payload) + + px_dec = int(market.get("price_decimals") or 4) + amt_dec = int(market.get("amount_decimals") or 4) + + def _f(v: Any, nd: int | None = None) -> Any: + if v is None: + return None + try: + return round(float(v), nd if nd is not None else 8) + except (TypeError, ValueError): + return v + + table = [] + for row in rows: + table.append( + { + "label": row.get("label"), + "price": _f(row.get("price"), px_dec), + "contracts": _f(row.get("contracts"), amt_dec), + "avg_entry": _f(row.get("avg_entry"), px_dec), + "profit_u": _f(row.get("profit_u")), + "risk_u": _f(row.get("risk_u")), + "rr": _f(row.get("rr"), 4), + } + ) + + return { + "direction": direction, + "capital_usdt": _f(capital), + "risk_percent": _f(rp, 2), + "risk_budget_u": _f(preview.get("preview_risk_amount_u")), + "leverage": lev, + "entry_price": _f(entry, px_dec), + "stop_loss": _f(sl, px_dec), + "add_upper": _f(upper, px_dec), + "take_profit": _f(tp, px_dec), + "plan_margin_u": _f(preview.get("plan_margin_capital")), + "target_contracts": _f(preview.get("target_order_amount"), amt_dec), + "first_contracts": _f(preview.get("first_order_amount"), amt_dec), + "dca_legs": int(preview.get("dca_legs") or legs), + "first_profit_u": _f(preview.get("preview_first_profit_u")), + "first_rr": _f(preview.get("preview_target_rr"), 4), + "market": market, + "rows": table, + }, None + + +def _amount_from_margin( + margin_capital: float, + leverage: int, + price: float, + contract_size: float, +) -> Optional[float]: + try: + margin = float(margin_capital) + lev = int(leverage) + px = float(price) + cs = float(contract_size) if contract_size else 1.0 + except (TypeError, ValueError): + return None + if margin <= 0 or lev <= 0 or px <= 0 or cs <= 0: + return None + notional = margin * lev + return notional / (px * cs) + + +def _round(v: Any, nd: int = 4) -> Any: + if v is None: + return None + try: + return round(float(v), nd) + except (TypeError, ValueError): + return v + + +def _money_rr(profit_u: Optional[float], risk_u: Optional[float]) -> Optional[float]: + try: + if risk_u is None or float(risk_u) <= 0 or profit_u is None: + return None + return round(float(profit_u) / float(risk_u), 4) + except (TypeError, ValueError): + return None + + +def calc_initial_roll_qty( + direction: str, + entry_price: float, + stop_loss: float, + risk_budget_usdt: float, + contract_size: float = 1.0, +) -> Tuple[Optional[float], Optional[str]]: + """首仓以损定仓:打到初始止损亏损 = 风险预算。""" + try: + entry = float(entry_price) + sl = float(stop_loss) + budget = float(risk_budget_usdt) + cs = float(contract_size) if contract_size else 1.0 + except (TypeError, ValueError): + return None, "参数格式错误" + if entry <= 0 or sl <= 0 or budget <= 0 or cs <= 0: + return None, "入场价、止损与风险预算须大于 0" + direction = (direction or "long").strip().lower() + if direction == "short": + per_unit = (sl - entry) * cs + if per_unit <= 0: + return None, "做空:止损价须高于首仓入场价" + else: + per_unit = (entry - sl) * cs + if per_unit <= 0: + return None, "做多:止损价须低于首仓入场价" + return budget / per_unit, None + + +def solve_add_amount_for_total_risk( + direction: str, + qty_existing: float, + entry_existing: float, + add_price: float, + new_stop: float, + risk_budget_usdt: float, + contract_size: float = 1.0, +) -> Tuple[Optional[float], Optional[str]]: + """合并持仓打到新止损总亏损 = 风险预算,反推本次加仓张数。""" + try: + q1 = float(qty_existing) + e1 = float(entry_existing) + e2 = float(add_price) + sl = float(new_stop) + b = float(risk_budget_usdt) + cs = float(contract_size) if contract_size else 1.0 + except (TypeError, ValueError): + return None, "参数格式错误" + if q1 <= 0 or e1 <= 0 or e2 <= 0 or b <= 0 or cs <= 0: + return None, "持仓或风险预算无效" + direction = (direction or "long").strip().lower() + if direction == "short": + denom = sl - e2 + numer = b / cs - q1 * (sl - e1) + if denom <= 0: + return None, "做空:新止损须高于限价加仓价" + else: + denom = e2 - sl + numer = b / cs - q1 * (e1 - sl) + if denom <= 0: + return None, "做多:新止损须低于限价/市价加仓价" + q2 = numer / denom + if q2 <= 0: + return None, "按当前新止损与总风险%,无需加仓或无法再加(已满足风险上限)" + return q2, None + + +def _roll_leg_preview( + *, + direction: str, + qty_existing: float, + entry_existing: float, + take_profit: float, + add_price: float, + new_stop_loss: float, + risk_budget: float, + contract_size: float, + amount_precise: Callable[[float], Optional[float]], +) -> Tuple[Optional[dict[str, Any]], Optional[str]]: + direction = (direction or "long").strip().lower() + try: + tp = float(take_profit) + sl = float(new_stop_loss) + entry_add = float(add_price) + e1 = float(entry_existing) + except (TypeError, ValueError): + return None, "止损/止盈格式错误" + if sl <= 0 or tp <= 0 or entry_add <= 0: + return None, "止损与首仓止盈须大于0" + if direction == "long": + if sl >= entry_add: + return None, "做多:新止损须低于加仓价" + if tp <= e1: + return None, "做多:首仓止盈须高于当前持仓均价参考" + else: + if sl <= entry_add: + return None, "做空:新止损须高于加仓价" + if tp >= e1: + return None, "做空:首仓止盈须低于当前持仓均价参考" + + q2_raw, err = solve_add_amount_for_total_risk( + direction, + qty_existing, + entry_existing, + entry_add, + sl, + risk_budget, + contract_size, + ) + if err: + return None, err + q2 = amount_precise(float(q2_raw)) + if q2 is None or q2 <= 0: + return None, "加仓张数低于交易所最小精度" + new_qty = float(qty_existing) + float(q2) + new_avg = (float(qty_existing) * float(entry_existing) + float(q2) * entry_add) / new_qty + cs = float(contract_size) if contract_size else 1.0 + if direction == "long": + loss_at_sl = (new_avg - sl) * new_qty * cs + reward_at_tp = (tp - new_avg) * new_qty * cs + else: + loss_at_sl = (sl - new_avg) * new_qty * cs + reward_at_tp = (new_avg - tp) * new_qty * cs + return { + "add_amount_raw": q2, + "qty_after": new_qty, + "avg_entry_after": new_avg, + "add_price": entry_add, + "new_stop_loss": sl, + "loss_at_sl_usdt": loss_at_sl, + "reward_at_tp_usdt": reward_at_tp, + }, None + + +def calc_roll_calculator( + *, + direction: str, + capital_usdt: float, + risk_percent: float, + entry_price: float, + stop_loss: float, + take_profit: float, + add_legs: list[dict[str, float]] | None = None, + legs_done: int = 0, + exchange_id: str = "0", + base: str = "ETH", +) -> Tuple[Optional[dict[str, Any]], Optional[str]]: + """ + 滚仓历史测算:首仓自动以损定仓;止盈锁定首仓价;最多 3 次滚仓加仓。 + add_legs: [{add_price, new_stop_loss}, ...],按顺序链式计算。 + legs_done: 已完成滚仓次数(仅标记,仍参与链式状态推进)。 + """ + market, amount_precise, merr = _resolve_market(exchange_id, base) + if merr or not market or not amount_precise: + return None, merr or "无法解析合约" + contract_size = float(market.get("contract_size") or 1.0) + px_dec = int(market.get("price_decimals") or 4) + amt_dec = int(market.get("amount_decimals") or 4) + + direction = (direction or "long").strip().lower() + if direction not in ("long", "short"): + return None, "方向须为 long 或 short" + try: + capital = float(capital_usdt) + rp = float(risk_percent) + entry = float(entry_price) + initial_sl = float(stop_loss) + tp = float(take_profit) + done = max(0, int(legs_done)) + except (TypeError, ValueError): + return None, "参数格式错误" + if capital <= 0 or rp <= 0 or entry <= 0 or initial_sl <= 0 or tp <= 0: + return None, "资金、风险与价格须大于 0" + if done > max_roll_legs(direction): + return None, f"已完成滚仓次数不能超过 {max_roll_legs(direction)} 次" + + legs_in: list[dict[str, float]] = [] + for raw in add_legs or []: + if not isinstance(raw, dict): + continue + try: + ap = float(raw.get("add_price")) + nsl = float(raw.get("new_stop_loss")) + except (TypeError, ValueError): + return None, "加仓价与新止损须为有效数字" + if ap <= 0 or nsl <= 0: + return None, "加仓价与新止损须大于 0" + legs_in.append({"add_price": ap, "new_stop_loss": nsl}) + + if done + len(legs_in) > max_roll_legs(direction): + return None, f"已完成 {done} 次 + 待测算 {len(legs_in)} 次,合计不能超过 {max_roll_legs(direction)} 次滚仓" + + if direction == "long": + if tp <= entry: + return None, "做多:止盈价须高于首仓入场价" + else: + if tp >= entry: + return None, "做空:止盈价须低于首仓入场价" + + risk_budget = capital * (rp / 100.0) + qty, err = calc_initial_roll_qty(direction, entry, initial_sl, risk_budget, contract_size) + if err: + return None, err + if qty is None or qty <= 0: + return None, "无法计算首仓张数" + qty_p = amount_precise(float(qty)) + if qty_p is None or qty_p <= 0: + return None, "首仓张数低于交易所最小精度" + + qty_f = float(qty_p) + avg = entry + rows: list[dict[str, Any]] = [] + cs = contract_size + + if direction == "long": + first_loss = (avg - initial_sl) * qty_f * cs + first_profit = (tp - avg) * qty_f * cs + else: + first_loss = (initial_sl - avg) * qty_f * cs + first_profit = (avg - tp) * qty_f * cs + + rows.append( + { + "label": "首仓", + "leg_index": 0, + "already_done": False, + "entry_or_add_price": _round(entry, px_dec), + "stop_loss": _round(initial_sl, px_dec), + "add_contracts": _round(qty_f, amt_dec), + "total_contracts": _round(qty_f, amt_dec), + "avg_entry": _round(avg, px_dec), + "take_profit": _round(tp, px_dec), + "loss_at_sl_u": _round(first_loss), + "profit_at_tp_u": _round(first_profit), + "rr": _money_rr(first_profit, first_loss), + } + ) + + current_qty = qty_f + current_avg = avg + + for i, leg in enumerate(legs_in): + leg_no = i + 1 + preview, err = _roll_leg_preview( + direction=direction, + qty_existing=current_qty, + entry_existing=current_avg, + take_profit=tp, + add_price=leg["add_price"], + new_stop_loss=leg["new_stop_loss"], + risk_budget=risk_budget, + contract_size=cs, + amount_precise=amount_precise, + ) + if err: + return None, f"滚仓第 {leg_no} 次:{err}" + if not preview: + return None, f"滚仓第 {leg_no} 次计算失败" + + current_qty = float(preview["qty_after"]) + current_avg = float(preview["avg_entry_after"]) + loss = preview.get("loss_at_sl_usdt") + reward = preview.get("reward_at_tp_usdt") + rows.append( + { + "label": f"滚仓{leg_no}", + "leg_index": leg_no, + "already_done": leg_no <= done, + "entry_or_add_price": _round(preview.get("add_price"), px_dec), + "stop_loss": _round(preview.get("new_stop_loss"), px_dec), + "add_contracts": _round(preview.get("add_amount_raw"), amt_dec), + "total_contracts": _round(current_qty, amt_dec), + "avg_entry": _round(current_avg, px_dec), + "take_profit": _round(tp, px_dec), + "loss_at_sl_u": _round(loss), + "profit_at_tp_u": _round(reward), + "rr": _money_rr(reward, loss), + } + ) + + last = rows[-1] + return { + "direction": direction, + "capital_usdt": _round(capital), + "risk_percent": _round(rp, 2), + "risk_budget_u": _round(risk_budget), + "entry_price": _round(entry, px_dec), + "stop_loss": _round(initial_sl, px_dec), + "take_profit": _round(tp, px_dec), + "legs_done": done, + "roll_legs_planned": len(legs_in), + "first_contracts": _round(qty_f, amt_dec), + "final_contracts": last.get("total_contracts"), + "final_avg_entry": last.get("avg_entry"), + "final_loss_at_sl_u": last.get("loss_at_sl_u"), + "final_profit_at_tp_u": last.get("profit_at_tp_u"), + "final_rr": last.get("rr"), + "market": market, + "rows": rows, + }, None diff --git a/hub_calculator_market_lib.py b/lib/hub/hub_calculator_market_lib.py similarity index 96% rename from hub_calculator_market_lib.py rename to lib/hub/hub_calculator_market_lib.py index 340c3fc..925c617 100644 --- a/hub_calculator_market_lib.py +++ b/lib/hub/hub_calculator_market_lib.py @@ -1,257 +1,257 @@ -"""计算器:从已配置交易实例读取 USDT 永续合约精度与张数规则。""" - -from __future__ import annotations - -import json -import threading -import time -import urllib.error -import urllib.request -from typing import Any, Callable, Optional, Tuple -from urllib.parse import urlencode - -try: - from settings_store import enabled_exchanges, load_settings -except ImportError: - from manual_trading_hub.settings_store import enabled_exchanges, load_settings - -MARKET_CACHE: dict[str, tuple[float, dict[str, Any]]] = {} -MARKET_LOCK = threading.Lock() -MARKET_TTL_SEC = 300.0 -HUB_FLASK_TIMEOUT = float(__import__("os").getenv("HUB_FLASK_TIMEOUT", "20")) - - -def normalize_base_symbol(text: str) -> str: - s = str(text or "").upper().strip() - for suf in ("USDT:USDT", "/USDT:USDT", "/USDT", "USDT", "-USDT-SWAP"): - if s.endswith(suf) and len(s) > len(suf): - s = s[: -len(suf)].strip("-/") - break - if "/" in s: - s = s.split("/", 1)[0].strip() - if ":" in s: - s = s.split(":", 1)[0].strip() - return s - - -def resolve_usdt_perp_symbol(exchange: Any, base: str) -> Tuple[Optional[str], Optional[str]]: - base_u = normalize_base_symbol(base) - if not base_u: - return None, "请输入币种,如 ETH" - candidates = [f"{base_u}/USDT:USDT", f"{base_u}/USDT"] - markets = getattr(exchange, "markets", None) or {} - for sym in candidates: - m = markets.get(sym) - if not m: - continue - if m.get("active") is False: - continue - if m.get("swap") or m.get("linear") or m.get("contract"): - return sym, None - for sym, m in markets.items(): - if m.get("active") is False: - continue - if not (m.get("swap") or m.get("linear")): - continue - if (m.get("quote") or "").upper() != "USDT": - continue - if (m.get("base") or "").upper() == base_u: - return sym, None - return None, f"未找到 {base_u}/USDT 永续合约" - - -def _decimals_from_precision_value(value: Any) -> Optional[int]: - if value in (None, ""): - return None - try: - p = float(value) - except (TypeError, ValueError): - return None - if p >= 1 and abs(p - round(p)) < 1e-9 and p <= 12: - return int(round(p)) - if 0 < p < 1: - s = f"{p:.12f}".rstrip("0") - if "." in s: - return min(12, len(s.split(".", 1)[1])) - return None - - -def _decimals_from_ccxt_str(text: str) -> int: - s = str(text or "").strip() - if not s or "." not in s: - return 0 - frac = s.split(".", 1)[1] - if not frac: - return 0 - return min(12, len(frac.rstrip("0") or frac)) - - -def amount_decimals_from_exchange(exchange: Any, exchange_symbol: str) -> int: - try: - return _decimals_from_ccxt_str(exchange.amount_to_precision(exchange_symbol, 1.23456789)) - except Exception: - market = exchange.market(exchange_symbol) - prec = (market.get("precision") or {}).get("amount") - d = _decimals_from_precision_value(prec) - return d if d is not None else 4 - - -def price_decimals_from_exchange( - exchange: Any, exchange_symbol: str, price_tick: Optional[float] -) -> int: - from hub_ohlcv_lib import normalize_price_tick - - tick = normalize_price_tick(price_tick) - if tick and tick > 0: - if tick >= 1: - return 0 - s = f"{tick:.12f}".rstrip("0") - if "." in s: - return min(12, len(s.split(".", 1)[1])) - try: - return _decimals_from_ccxt_str(exchange.price_to_precision(exchange_symbol, 12345.678901234)) - except Exception: - market = exchange.market(exchange_symbol) - prec = (market.get("precision") or {}).get("price") - d = _decimals_from_precision_value(prec) - return d if d is not None else 4 - - -def make_amount_precise_fn_from_market(market: dict[str, Any]) -> Callable[[float], Optional[float]]: - dec = max(0, int(market.get("amount_decimals") or 4)) - min_amt = market.get("min_amount") - - def _fn(amount: float) -> Optional[float]: - try: - v = float(amount) - except (TypeError, ValueError): - return None - if v <= 0: - return None - factor = 10**dec - v = int(v * factor + 1e-12) / factor - if min_amt is not None: - try: - if v < float(min_amt): - return None - except (TypeError, ValueError): - pass - if v <= 0: - return None - return v - - return _fn - - -def find_exchange(exchange_id: str) -> dict | None: - needle = str(exchange_id or "").strip() - if not needle: - return None - for ex in load_settings().get("exchanges") or []: - if str(ex.get("id") or "").strip() == needle: - return ex - if str(ex.get("key") or "").strip().lower() == needle.lower(): - return ex - return None - - -def list_calculator_exchanges() -> list[dict[str, Any]]: - rows: list[dict[str, Any]] = [] - for ex in enabled_exchanges(): - rows.append( - { - "id": str(ex.get("id") or ""), - "key": str(ex.get("key") or ""), - "name": str(ex.get("name") or ex.get("key") or ""), - "enabled": bool(ex.get("enabled")), - } - ) - return rows - - -def _hub_headers() -> dict[str, str]: - import os - - token = (os.getenv("HUB_BRIDGE_TOKEN") or os.getenv("CONTROL_TOKEN") or "").strip() - if token: - return {"X-Hub-Token": token} - return {} - - -def fetch_instance_market_sync(ex: dict, *, base: str) -> dict[str, Any]: - base_url = (ex.get("flask_url") or "").rstrip("/") - if not base_url: - return {"ok": False, "msg": "未配置 flask_url"} - params = urlencode({"base": normalize_base_symbol(base) or base}) - url = f"{base_url}/api/hub/market?{params}" - req = urllib.request.Request(url, headers=_hub_headers(), method="GET") - try: - with urllib.request.urlopen(req, timeout=HUB_FLASK_TIMEOUT) as resp: - status = int(getattr(resp, "status", 200) or 200) - raw = resp.read().decode("utf-8", errors="replace") - data = json.loads(raw) if raw else {} - if not isinstance(data, dict): - return {"ok": False, "msg": "无效 JSON"} - if status >= 400: - data.setdefault("ok", False) - return data - except urllib.error.HTTPError as exc: - try: - raw = exc.read().decode("utf-8", errors="replace") - body = json.loads(raw) if raw else {} - except Exception: - body = {"ok": False, "msg": raw if "raw" in locals() else str(exc)} - if isinstance(body, dict): - body.setdefault("ok", False) - return body - return {"ok": False, "msg": f"HTTP {exc.code}"} - except Exception as exc: - return {"ok": False, "msg": str(exc)} - - -def _enrich_market_from_settings(ex: dict, payload: dict[str, Any]) -> dict[str, Any]: - out = dict(payload) - out["exchange_id"] = str(ex.get("id") or "") - out["exchange_key"] = str(ex.get("key") or "") - out["exchange_name"] = str(ex.get("name") or ex.get("key") or "") - out["exchange_label"] = out["exchange_name"] - return out - - -def get_calculator_market( - exchange_id: str, - base: str, - *, - ex: dict | None = None, -) -> Tuple[Optional[dict[str, Any]], Optional[str]]: - """从系统设置中的交易实例拉取合约精度(与实盘一致)。""" - row = ex or find_exchange(exchange_id) - if not row: - return None, "未找到该交易所配置" - if not row.get("enabled"): - return None, f"{row.get('name') or exchange_id} 未启用" - - base_u = normalize_base_symbol(base) - if not base_u: - return None, "请输入币种,如 ETH" - - cache_key = f"{row.get('id')}:{base_u}" - now = time.time() - with MARKET_LOCK: - cached = MARKET_CACHE.get(cache_key) - if cached and now - cached[0] < MARKET_TTL_SEC: - return dict(cached[1]), None - - remote = fetch_instance_market_sync(row, base=base_u) - if not remote.get("ok"): - return None, str(remote.get("msg") or "实例返回失败") - - data = _enrich_market_from_settings(row, remote) - with MARKET_LOCK: - MARKET_CACHE[cache_key] = (now, data) - return data, None - - -def clear_market_cache() -> None: - with MARKET_LOCK: - MARKET_CACHE.clear() +"""计算器:从已配置交易实例读取 USDT 永续合约精度与张数规则。""" + +from __future__ import annotations + +import json +import threading +import time +import urllib.error +import urllib.request +from typing import Any, Callable, Optional, Tuple +from urllib.parse import urlencode + +try: + from settings_store import enabled_exchanges, load_settings +except ImportError: + from manual_trading_hub.settings_store import enabled_exchanges, load_settings + +MARKET_CACHE: dict[str, tuple[float, dict[str, Any]]] = {} +MARKET_LOCK = threading.Lock() +MARKET_TTL_SEC = 300.0 +HUB_FLASK_TIMEOUT = float(__import__("os").getenv("HUB_FLASK_TIMEOUT", "20")) + + +def normalize_base_symbol(text: str) -> str: + s = str(text or "").upper().strip() + for suf in ("USDT:USDT", "/USDT:USDT", "/USDT", "USDT", "-USDT-SWAP"): + if s.endswith(suf) and len(s) > len(suf): + s = s[: -len(suf)].strip("-/") + break + if "/" in s: + s = s.split("/", 1)[0].strip() + if ":" in s: + s = s.split(":", 1)[0].strip() + return s + + +def resolve_usdt_perp_symbol(exchange: Any, base: str) -> Tuple[Optional[str], Optional[str]]: + base_u = normalize_base_symbol(base) + if not base_u: + return None, "请输入币种,如 ETH" + candidates = [f"{base_u}/USDT:USDT", f"{base_u}/USDT"] + markets = getattr(exchange, "markets", None) or {} + for sym in candidates: + m = markets.get(sym) + if not m: + continue + if m.get("active") is False: + continue + if m.get("swap") or m.get("linear") or m.get("contract"): + return sym, None + for sym, m in markets.items(): + if m.get("active") is False: + continue + if not (m.get("swap") or m.get("linear")): + continue + if (m.get("quote") or "").upper() != "USDT": + continue + if (m.get("base") or "").upper() == base_u: + return sym, None + return None, f"未找到 {base_u}/USDT 永续合约" + + +def _decimals_from_precision_value(value: Any) -> Optional[int]: + if value in (None, ""): + return None + try: + p = float(value) + except (TypeError, ValueError): + return None + if p >= 1 and abs(p - round(p)) < 1e-9 and p <= 12: + return int(round(p)) + if 0 < p < 1: + s = f"{p:.12f}".rstrip("0") + if "." in s: + return min(12, len(s.split(".", 1)[1])) + return None + + +def _decimals_from_ccxt_str(text: str) -> int: + s = str(text or "").strip() + if not s or "." not in s: + return 0 + frac = s.split(".", 1)[1] + if not frac: + return 0 + return min(12, len(frac.rstrip("0") or frac)) + + +def amount_decimals_from_exchange(exchange: Any, exchange_symbol: str) -> int: + try: + return _decimals_from_ccxt_str(exchange.amount_to_precision(exchange_symbol, 1.23456789)) + except Exception: + market = exchange.market(exchange_symbol) + prec = (market.get("precision") or {}).get("amount") + d = _decimals_from_precision_value(prec) + return d if d is not None else 4 + + +def price_decimals_from_exchange( + exchange: Any, exchange_symbol: str, price_tick: Optional[float] +) -> int: + from lib.hub.hub_ohlcv_lib import normalize_price_tick + + tick = normalize_price_tick(price_tick) + if tick and tick > 0: + if tick >= 1: + return 0 + s = f"{tick:.12f}".rstrip("0") + if "." in s: + return min(12, len(s.split(".", 1)[1])) + try: + return _decimals_from_ccxt_str(exchange.price_to_precision(exchange_symbol, 12345.678901234)) + except Exception: + market = exchange.market(exchange_symbol) + prec = (market.get("precision") or {}).get("price") + d = _decimals_from_precision_value(prec) + return d if d is not None else 4 + + +def make_amount_precise_fn_from_market(market: dict[str, Any]) -> Callable[[float], Optional[float]]: + dec = max(0, int(market.get("amount_decimals") or 4)) + min_amt = market.get("min_amount") + + def _fn(amount: float) -> Optional[float]: + try: + v = float(amount) + except (TypeError, ValueError): + return None + if v <= 0: + return None + factor = 10**dec + v = int(v * factor + 1e-12) / factor + if min_amt is not None: + try: + if v < float(min_amt): + return None + except (TypeError, ValueError): + pass + if v <= 0: + return None + return v + + return _fn + + +def find_exchange(exchange_id: str) -> dict | None: + needle = str(exchange_id or "").strip() + if not needle: + return None + for ex in load_settings().get("exchanges") or []: + if str(ex.get("id") or "").strip() == needle: + return ex + if str(ex.get("key") or "").strip().lower() == needle.lower(): + return ex + return None + + +def list_calculator_exchanges() -> list[dict[str, Any]]: + rows: list[dict[str, Any]] = [] + for ex in enabled_exchanges(): + rows.append( + { + "id": str(ex.get("id") or ""), + "key": str(ex.get("key") or ""), + "name": str(ex.get("name") or ex.get("key") or ""), + "enabled": bool(ex.get("enabled")), + } + ) + return rows + + +def _hub_headers() -> dict[str, str]: + import os + + token = (os.getenv("HUB_BRIDGE_TOKEN") or os.getenv("CONTROL_TOKEN") or "").strip() + if token: + return {"X-Hub-Token": token} + return {} + + +def fetch_instance_market_sync(ex: dict, *, base: str) -> dict[str, Any]: + base_url = (ex.get("flask_url") or "").rstrip("/") + if not base_url: + return {"ok": False, "msg": "未配置 flask_url"} + params = urlencode({"base": normalize_base_symbol(base) or base}) + url = f"{base_url}/api/hub/market?{params}" + req = urllib.request.Request(url, headers=_hub_headers(), method="GET") + try: + with urllib.request.urlopen(req, timeout=HUB_FLASK_TIMEOUT) as resp: + status = int(getattr(resp, "status", 200) or 200) + raw = resp.read().decode("utf-8", errors="replace") + data = json.loads(raw) if raw else {} + if not isinstance(data, dict): + return {"ok": False, "msg": "无效 JSON"} + if status >= 400: + data.setdefault("ok", False) + return data + except urllib.error.HTTPError as exc: + try: + raw = exc.read().decode("utf-8", errors="replace") + body = json.loads(raw) if raw else {} + except Exception: + body = {"ok": False, "msg": raw if "raw" in locals() else str(exc)} + if isinstance(body, dict): + body.setdefault("ok", False) + return body + return {"ok": False, "msg": f"HTTP {exc.code}"} + except Exception as exc: + return {"ok": False, "msg": str(exc)} + + +def _enrich_market_from_settings(ex: dict, payload: dict[str, Any]) -> dict[str, Any]: + out = dict(payload) + out["exchange_id"] = str(ex.get("id") or "") + out["exchange_key"] = str(ex.get("key") or "") + out["exchange_name"] = str(ex.get("name") or ex.get("key") or "") + out["exchange_label"] = out["exchange_name"] + return out + + +def get_calculator_market( + exchange_id: str, + base: str, + *, + ex: dict | None = None, +) -> Tuple[Optional[dict[str, Any]], Optional[str]]: + """从系统设置中的交易实例拉取合约精度(与实盘一致)。""" + row = ex or find_exchange(exchange_id) + if not row: + return None, "未找到该交易所配置" + if not row.get("enabled"): + return None, f"{row.get('name') or exchange_id} 未启用" + + base_u = normalize_base_symbol(base) + if not base_u: + return None, "请输入币种,如 ETH" + + cache_key = f"{row.get('id')}:{base_u}" + now = time.time() + with MARKET_LOCK: + cached = MARKET_CACHE.get(cache_key) + if cached and now - cached[0] < MARKET_TTL_SEC: + return dict(cached[1]), None + + remote = fetch_instance_market_sync(row, base=base_u) + if not remote.get("ok"): + return None, str(remote.get("msg") or "实例返回失败") + + data = _enrich_market_from_settings(row, remote) + with MARKET_LOCK: + MARKET_CACHE[cache_key] = (now, data) + return data, None + + +def clear_market_cache() -> None: + with MARKET_LOCK: + MARKET_CACHE.clear() diff --git a/hub_entry_plan_lib.py b/lib/hub/hub_entry_plan_lib.py similarity index 100% rename from hub_entry_plan_lib.py rename to lib/hub/hub_entry_plan_lib.py diff --git a/hub_fund_history_lib.py b/lib/hub/hub_fund_history_lib.py similarity index 96% rename from hub_fund_history_lib.py rename to lib/hub/hub_fund_history_lib.py index cd037a9..0b1023a 100644 --- a/hub_fund_history_lib.py +++ b/lib/hub/hub_fund_history_lib.py @@ -1,407 +1,407 @@ -"""中控资金概况:分户日快照(180 交易日)、总资金曲线与回撤。""" -from __future__ import annotations - -import json -import os -from datetime import datetime, timedelta -from pathlib import Path -from typing import Any, Optional - -from hub_trades_lib import current_trading_day - -HUB_DIR = Path(__file__).resolve().parent / "manual_trading_hub" -FUND_HISTORY_PATH = HUB_DIR / "hub_fund_history.json" -LEGACY_FUND_HISTORY_PATH = HUB_DIR / "hub_ai_fund_history.json" - -try: - FUND_HISTORY_DAYS = max(30, int(os.getenv("HUB_FUND_HISTORY_DAYS", "180") or "180")) -except ValueError: - FUND_HISTORY_DAYS = 180 - -FUND_HISTORY_START_DAY = (os.getenv("HUB_FUND_HISTORY_START_DAY") or "2026-06-09").strip()[:10] - - -def fund_history_start_day() -> str: - return FUND_HISTORY_START_DAY or "2026-06-09" - - -def _now_str() -> str: - return datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - -def _safe_float(value: Any) -> Optional[float]: - try: - v = float(value) - return v if v >= 0 else None - except (TypeError, ValueError): - return None - - -def account_total_usdt(funding: Any, trading: Any) -> Optional[float]: - """资金户 + 交易户;任一侧缺失则不计入(返回 None)。""" - fu = _safe_float(funding) - tu = _safe_float(trading) - if fu is None or tu is None: - return None - return round(fu + tu, 4) - - -def compute_drawdown(values: list[float]) -> dict[str, Any]: - """基于资金权益序列计算峰值回撤(U 与 %)。""" - peak = 0.0 - max_dd_u = 0.0 - peak_at_end = 0.0 - for v in values: - if not isinstance(v, (int, float)): - continue - fv = float(v) - if fv > peak: - peak = fv - dd = peak - fv - if dd > max_dd_u: - max_dd_u = dd - peak_at_end = peak - max_dd_u = round(max_dd_u, 4) - peak_at_end = round(peak_at_end, 4) - max_dd_pct = round((max_dd_u / peak_at_end) * 100, 2) if peak_at_end > 0 else None - return { - "peak_usdt": peak_at_end, - "max_drawdown_u": max_dd_u, - "max_drawdown_pct": max_dd_pct, - } - - -def _atomic_write(path: Path, data: dict) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - tmp = path.with_suffix(path.suffix + ".tmp") - tmp.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") - os.replace(tmp, path) - - -def _prune_days( - days: dict, - *, - keep_days: int, - anchor_day: str, - start_day: Optional[str] = None, -) -> dict: - try: - anchor = datetime.strptime(anchor_day[:10], "%Y-%m-%d") - except ValueError: - anchor = datetime.now() - rolling_cutoff = (anchor - timedelta(days=max(1, keep_days) - 1)).strftime("%Y-%m-%d") - start = (start_day or fund_history_start_day()).strip()[:10] - cutoff = max(rolling_cutoff, start) if start else rolling_cutoff - return {k: v for k, v in (days or {}).items() if str(k) >= cutoff} - - -def _migrate_legacy_store(days: dict) -> dict: - if not LEGACY_FUND_HISTORY_PATH.is_file(): - return days - try: - loaded = json.loads(LEGACY_FUND_HISTORY_PATH.read_text(encoding="utf-8")) - legacy_days = loaded.get("days") if isinstance(loaded, dict) else {} - if not isinstance(legacy_days, dict): - return days - merged = dict(days) - for day, block in legacy_days.items(): - if day in merged: - continue - if isinstance(block, dict) and block.get("accounts"): - merged[day] = block - return merged - except Exception: - return days - - -def _load_store() -> dict: - if not FUND_HISTORY_PATH.is_file(): - store = {"version": 1, "days": _migrate_legacy_store({})} - if store["days"]: - _atomic_write(FUND_HISTORY_PATH, store) - return store - try: - loaded = json.loads(FUND_HISTORY_PATH.read_text(encoding="utf-8")) - if isinstance(loaded, dict): - loaded.setdefault("version", 1) - days = dict(loaded.get("days") or {}) - loaded["days"] = _migrate_legacy_store(days) - return loaded - except Exception: - pass - return {"version": 1, "days": {}} - - -def record_fund_snapshot( - trading_day: str, - accounts: list[dict], - *, - keep_days: int = FUND_HISTORY_DAYS, - reset_hour: int = 8, -) -> dict[str, Any]: - """写入当日各户资金账户/交易账户余额,并裁剪历史。""" - day = (trading_day or "").strip()[:10] or current_trading_day(reset_hour=reset_hour) - start = fund_history_start_day() - if start and day < start: - return _load_store().get("days") or {} - store = _load_store() - days = dict(store.get("days") or {}) - row_accounts: dict[str, dict] = {} - for ac in accounts or []: - key = str(ac.get("key") or ac.get("id") or "").strip() - if not key: - continue - if not ac.get("monitored"): - continue - fu = _safe_float(ac.get("funding_usdt")) - tu = _safe_float(ac.get("trading_usdt")) - total = account_total_usdt(fu, tu) - if total is None: - continue - row_accounts[key] = { - "name": ac.get("name"), - "funding_usdt": fu, - "trading_usdt": tu, - "total_usdt": total, - "recorded_at": _now_str(), - } - if row_accounts: - days[day] = {"accounts": row_accounts, "updated_at": _now_str()} - days = _prune_days( - days, keep_days=keep_days, anchor_day=day, start_day=fund_history_start_day() - ) - _atomic_write(FUND_HISTORY_PATH, {"version": 1, "days": days}) - return days - - -def record_fund_snapshot_from_board( - rows: list[dict], - *, - keep_days: int = FUND_HISTORY_DAYS, - reset_hour: int = 8, -) -> dict[str, Any]: - """监控板行写入当日快照(仅 account_ok 且资金/交易户齐全)。""" - day = current_trading_day(reset_hour=reset_hour) - accounts = [] - for row in rows or []: - if not isinstance(row, dict): - continue - if not row.get("account_ok"): - continue - accounts.append( - { - "key": row.get("key") or row.get("id"), - "name": row.get("name"), - "funding_usdt": row.get("funding_usdt"), - "trading_usdt": row.get("trading_usdt"), - "monitored": True, - } - ) - return record_fund_snapshot(day, accounts, keep_days=keep_days, reset_hour=reset_hour) - - -def get_fund_history(*, anchor_day: str, keep_days: int = FUND_HISTORY_DAYS) -> dict[str, dict]: - store = _load_store() - return _prune_days( - dict(store.get("days") or {}), - keep_days=keep_days, - anchor_day=anchor_day, - start_day=fund_history_start_day(), - ) - - -def _exchange_monitored(ex: dict) -> bool: - return bool(ex.get("enabled")) and not bool(ex.get("env_disabled")) - - -def _live_row_for_exchange(ex: dict, rows_by_key: dict[str, dict]) -> Optional[dict]: - key = str(ex.get("key") or "").strip() - if not key: - return None - return rows_by_key.get(key) - - -def _series_from_history( - history: dict[str, dict], - account_keys: list[str], -) -> list[dict[str, Any]]: - out: list[dict[str, Any]] = [] - for day in sorted(history.keys()): - block = history.get(day) or {} - ac_map = block.get("accounts") or {} - total = 0.0 - n = 0 - for key in account_keys: - ac = ac_map.get(key) or {} - t = account_total_usdt(ac.get("funding_usdt"), ac.get("trading_usdt")) - if t is None: - t = _safe_float(ac.get("total_usdt")) - if t is None: - continue - total += t - n += 1 - if n > 0: - out.append({"day": day, "total_usdt": round(total, 4)}) - return out - - -def _account_series(history: dict[str, dict], key: str) -> list[dict[str, Any]]: - out: list[dict[str, Any]] = [] - for day in sorted(history.keys()): - ac = (history.get(day) or {}).get("accounts", {}).get(key) or {} - t = account_total_usdt(ac.get("funding_usdt"), ac.get("trading_usdt")) - if t is None: - t = _safe_float(ac.get("total_usdt")) - if t is None: - continue - out.append( - { - "day": day, - "total_usdt": t, - "funding_usdt": _safe_float(ac.get("funding_usdt")), - "trading_usdt": _safe_float(ac.get("trading_usdt")), - } - ) - return out - - -def build_fund_overview( - exchanges: list[dict], - *, - board_rows: Optional[list[dict]] = None, - trading_day: Optional[str] = None, - keep_days: int = FUND_HISTORY_DAYS, - reset_hour: int = 8, - updated_at: Optional[str] = None, -) -> dict[str, Any]: - day = (trading_day or "").strip()[:10] or current_trading_day(reset_hour=reset_hour) - history = get_fund_history(anchor_day=day, keep_days=keep_days) - rows_by_key: dict[str, dict] = {} - for row in board_rows or []: - if isinstance(row, dict): - k = str(row.get("key") or "").strip() - if k: - rows_by_key[k] = row - - monitored_keys: list[str] = [] - accounts_out: list[dict[str, Any]] = [] - live_total = 0.0 - live_known = 0 - - for ex in exchanges or []: - if not _exchange_monitored(ex): - continue - key = str(ex.get("key") or "").strip() - monitored = True - row = _live_row_for_exchange(ex, rows_by_key) - fu = tu = total = None - data_ok = False - if row and row.get("account_ok"): - fu = _safe_float(row.get("funding_usdt")) - tu = _safe_float(row.get("trading_usdt")) - total = account_total_usdt(fu, tu) - data_ok = total is not None - if data_ok: - live_total += total - live_known += 1 - - series = _account_series(history, key) if key else [] - dd = compute_drawdown([p["total_usdt"] for p in series]) if series else { - "peak_usdt": None, - "max_drawdown_u": None, - "max_drawdown_pct": None, - } - day_delta = None - if series: - if len(series) >= 2: - day_delta = round(series[-1]["total_usdt"] - series[-2]["total_usdt"], 4) - elif data_ok and total is not None: - day_delta = round(total - series[-1]["total_usdt"], 4) - - accounts_out.append( - { - "id": ex.get("id"), - "key": key, - "name": ex.get("name") or key, - "monitored": monitored, - "data_ok": data_ok, - "funding_usdt": fu, - "trading_usdt": tu, - "total_usdt": total, - "series": series, - "drawdown": dd, - "day_delta_usdt": day_delta, - } - ) - if key: - monitored_keys.append(key) - - total_series = _series_from_history(history, monitored_keys) - if live_known > 0: - last_day = total_series[-1]["day"] if total_series else None - live_point = round(live_total, 4) - if last_day == day and total_series: - total_series[-1]["total_usdt"] = live_point - total_series[-1]["live"] = True - else: - total_series.append({"day": day, "total_usdt": live_point, "live": True}) - - total_dd = compute_drawdown([p["total_usdt"] for p in total_series]) if total_series else { - "peak_usdt": None, - "max_drawdown_u": None, - "max_drawdown_pct": None, - } - total_day_delta = None - if total_series: - if len(total_series) >= 2: - total_day_delta = round( - total_series[-1]["total_usdt"] - total_series[-2]["total_usdt"], 4 - ) - - return { - "ok": True, - "trading_day": day, - "reset_hour": reset_hour, - "keep_days": keep_days, - "history_start_day": fund_history_start_day(), - "updated_at": updated_at, - "totals": { - "monitored_count": len(monitored_keys), - "live_known_count": live_known, - "total_usdt": round(live_total, 4) if live_known > 0 else None, - "day_delta_usdt": total_day_delta, - "series": total_series, - "drawdown": total_dd, - }, - "accounts": accounts_out, - } - - -def format_fund_history_text( - history: dict[str, dict], - *, - account_names: Optional[dict[str, str]] = None, -) -> str: - if not history: - return "(暂无资金历史快照)" - names = account_names or {} - lines = ["【资金快照(资金账户 + 交易账户 USDT)】"] - for day in sorted(history.keys()): - block = history.get(day) or {} - ac_map = block.get("accounts") or {} - if not ac_map: - continue - parts = [] - for key, ac in ac_map.items(): - label = names.get(key) or ac.get("name") or key - fu = ac.get("funding_usdt") - tu = ac.get("trading_usdt") - tot = ac.get("total_usdt") - if tot is None: - tot = account_total_usdt(fu, tu) - fu_txt = f"{fu}U" if fu is not None else "未知" - tu_txt = f"{tu}U" if tu is not None else "未知" - tot_txt = f"{tot}U" if tot is not None else "未知" - parts.append(f"{label}: 合计{tot_txt}(资金{fu_txt}/交易{tu_txt})") - lines.append(f"- {day}: " + ";".join(parts)) - return "\n".join(lines) if len(lines) > 1 else "(暂无资金历史快照)" +"""中控资金概况:分户日快照(180 交易日)、总资金曲线与回撤。""" +from __future__ import annotations + +import json +import os +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Optional + +from lib.hub.hub_trades_lib import current_trading_day + +HUB_DIR = Path(__file__).resolve().parent / "manual_trading_hub" +FUND_HISTORY_PATH = HUB_DIR / "hub_fund_history.json" +LEGACY_FUND_HISTORY_PATH = HUB_DIR / "hub_ai_fund_history.json" + +try: + FUND_HISTORY_DAYS = max(30, int(os.getenv("HUB_FUND_HISTORY_DAYS", "180") or "180")) +except ValueError: + FUND_HISTORY_DAYS = 180 + +FUND_HISTORY_START_DAY = (os.getenv("HUB_FUND_HISTORY_START_DAY") or "2026-06-09").strip()[:10] + + +def fund_history_start_day() -> str: + return FUND_HISTORY_START_DAY or "2026-06-09" + + +def _now_str() -> str: + return datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def _safe_float(value: Any) -> Optional[float]: + try: + v = float(value) + return v if v >= 0 else None + except (TypeError, ValueError): + return None + + +def account_total_usdt(funding: Any, trading: Any) -> Optional[float]: + """资金户 + 交易户;任一侧缺失则不计入(返回 None)。""" + fu = _safe_float(funding) + tu = _safe_float(trading) + if fu is None or tu is None: + return None + return round(fu + tu, 4) + + +def compute_drawdown(values: list[float]) -> dict[str, Any]: + """基于资金权益序列计算峰值回撤(U 与 %)。""" + peak = 0.0 + max_dd_u = 0.0 + peak_at_end = 0.0 + for v in values: + if not isinstance(v, (int, float)): + continue + fv = float(v) + if fv > peak: + peak = fv + dd = peak - fv + if dd > max_dd_u: + max_dd_u = dd + peak_at_end = peak + max_dd_u = round(max_dd_u, 4) + peak_at_end = round(peak_at_end, 4) + max_dd_pct = round((max_dd_u / peak_at_end) * 100, 2) if peak_at_end > 0 else None + return { + "peak_usdt": peak_at_end, + "max_drawdown_u": max_dd_u, + "max_drawdown_pct": max_dd_pct, + } + + +def _atomic_write(path: Path, data: dict) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + tmp.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, path) + + +def _prune_days( + days: dict, + *, + keep_days: int, + anchor_day: str, + start_day: Optional[str] = None, +) -> dict: + try: + anchor = datetime.strptime(anchor_day[:10], "%Y-%m-%d") + except ValueError: + anchor = datetime.now() + rolling_cutoff = (anchor - timedelta(days=max(1, keep_days) - 1)).strftime("%Y-%m-%d") + start = (start_day or fund_history_start_day()).strip()[:10] + cutoff = max(rolling_cutoff, start) if start else rolling_cutoff + return {k: v for k, v in (days or {}).items() if str(k) >= cutoff} + + +def _migrate_legacy_store(days: dict) -> dict: + if not LEGACY_FUND_HISTORY_PATH.is_file(): + return days + try: + loaded = json.loads(LEGACY_FUND_HISTORY_PATH.read_text(encoding="utf-8")) + legacy_days = loaded.get("days") if isinstance(loaded, dict) else {} + if not isinstance(legacy_days, dict): + return days + merged = dict(days) + for day, block in legacy_days.items(): + if day in merged: + continue + if isinstance(block, dict) and block.get("accounts"): + merged[day] = block + return merged + except Exception: + return days + + +def _load_store() -> dict: + if not FUND_HISTORY_PATH.is_file(): + store = {"version": 1, "days": _migrate_legacy_store({})} + if store["days"]: + _atomic_write(FUND_HISTORY_PATH, store) + return store + try: + loaded = json.loads(FUND_HISTORY_PATH.read_text(encoding="utf-8")) + if isinstance(loaded, dict): + loaded.setdefault("version", 1) + days = dict(loaded.get("days") or {}) + loaded["days"] = _migrate_legacy_store(days) + return loaded + except Exception: + pass + return {"version": 1, "days": {}} + + +def record_fund_snapshot( + trading_day: str, + accounts: list[dict], + *, + keep_days: int = FUND_HISTORY_DAYS, + reset_hour: int = 8, +) -> dict[str, Any]: + """写入当日各户资金账户/交易账户余额,并裁剪历史。""" + day = (trading_day or "").strip()[:10] or current_trading_day(reset_hour=reset_hour) + start = fund_history_start_day() + if start and day < start: + return _load_store().get("days") or {} + store = _load_store() + days = dict(store.get("days") or {}) + row_accounts: dict[str, dict] = {} + for ac in accounts or []: + key = str(ac.get("key") or ac.get("id") or "").strip() + if not key: + continue + if not ac.get("monitored"): + continue + fu = _safe_float(ac.get("funding_usdt")) + tu = _safe_float(ac.get("trading_usdt")) + total = account_total_usdt(fu, tu) + if total is None: + continue + row_accounts[key] = { + "name": ac.get("name"), + "funding_usdt": fu, + "trading_usdt": tu, + "total_usdt": total, + "recorded_at": _now_str(), + } + if row_accounts: + days[day] = {"accounts": row_accounts, "updated_at": _now_str()} + days = _prune_days( + days, keep_days=keep_days, anchor_day=day, start_day=fund_history_start_day() + ) + _atomic_write(FUND_HISTORY_PATH, {"version": 1, "days": days}) + return days + + +def record_fund_snapshot_from_board( + rows: list[dict], + *, + keep_days: int = FUND_HISTORY_DAYS, + reset_hour: int = 8, +) -> dict[str, Any]: + """监控板行写入当日快照(仅 account_ok 且资金/交易户齐全)。""" + day = current_trading_day(reset_hour=reset_hour) + accounts = [] + for row in rows or []: + if not isinstance(row, dict): + continue + if not row.get("account_ok"): + continue + accounts.append( + { + "key": row.get("key") or row.get("id"), + "name": row.get("name"), + "funding_usdt": row.get("funding_usdt"), + "trading_usdt": row.get("trading_usdt"), + "monitored": True, + } + ) + return record_fund_snapshot(day, accounts, keep_days=keep_days, reset_hour=reset_hour) + + +def get_fund_history(*, anchor_day: str, keep_days: int = FUND_HISTORY_DAYS) -> dict[str, dict]: + store = _load_store() + return _prune_days( + dict(store.get("days") or {}), + keep_days=keep_days, + anchor_day=anchor_day, + start_day=fund_history_start_day(), + ) + + +def _exchange_monitored(ex: dict) -> bool: + return bool(ex.get("enabled")) and not bool(ex.get("env_disabled")) + + +def _live_row_for_exchange(ex: dict, rows_by_key: dict[str, dict]) -> Optional[dict]: + key = str(ex.get("key") or "").strip() + if not key: + return None + return rows_by_key.get(key) + + +def _series_from_history( + history: dict[str, dict], + account_keys: list[str], +) -> list[dict[str, Any]]: + out: list[dict[str, Any]] = [] + for day in sorted(history.keys()): + block = history.get(day) or {} + ac_map = block.get("accounts") or {} + total = 0.0 + n = 0 + for key in account_keys: + ac = ac_map.get(key) or {} + t = account_total_usdt(ac.get("funding_usdt"), ac.get("trading_usdt")) + if t is None: + t = _safe_float(ac.get("total_usdt")) + if t is None: + continue + total += t + n += 1 + if n > 0: + out.append({"day": day, "total_usdt": round(total, 4)}) + return out + + +def _account_series(history: dict[str, dict], key: str) -> list[dict[str, Any]]: + out: list[dict[str, Any]] = [] + for day in sorted(history.keys()): + ac = (history.get(day) or {}).get("accounts", {}).get(key) or {} + t = account_total_usdt(ac.get("funding_usdt"), ac.get("trading_usdt")) + if t is None: + t = _safe_float(ac.get("total_usdt")) + if t is None: + continue + out.append( + { + "day": day, + "total_usdt": t, + "funding_usdt": _safe_float(ac.get("funding_usdt")), + "trading_usdt": _safe_float(ac.get("trading_usdt")), + } + ) + return out + + +def build_fund_overview( + exchanges: list[dict], + *, + board_rows: Optional[list[dict]] = None, + trading_day: Optional[str] = None, + keep_days: int = FUND_HISTORY_DAYS, + reset_hour: int = 8, + updated_at: Optional[str] = None, +) -> dict[str, Any]: + day = (trading_day or "").strip()[:10] or current_trading_day(reset_hour=reset_hour) + history = get_fund_history(anchor_day=day, keep_days=keep_days) + rows_by_key: dict[str, dict] = {} + for row in board_rows or []: + if isinstance(row, dict): + k = str(row.get("key") or "").strip() + if k: + rows_by_key[k] = row + + monitored_keys: list[str] = [] + accounts_out: list[dict[str, Any]] = [] + live_total = 0.0 + live_known = 0 + + for ex in exchanges or []: + if not _exchange_monitored(ex): + continue + key = str(ex.get("key") or "").strip() + monitored = True + row = _live_row_for_exchange(ex, rows_by_key) + fu = tu = total = None + data_ok = False + if row and row.get("account_ok"): + fu = _safe_float(row.get("funding_usdt")) + tu = _safe_float(row.get("trading_usdt")) + total = account_total_usdt(fu, tu) + data_ok = total is not None + if data_ok: + live_total += total + live_known += 1 + + series = _account_series(history, key) if key else [] + dd = compute_drawdown([p["total_usdt"] for p in series]) if series else { + "peak_usdt": None, + "max_drawdown_u": None, + "max_drawdown_pct": None, + } + day_delta = None + if series: + if len(series) >= 2: + day_delta = round(series[-1]["total_usdt"] - series[-2]["total_usdt"], 4) + elif data_ok and total is not None: + day_delta = round(total - series[-1]["total_usdt"], 4) + + accounts_out.append( + { + "id": ex.get("id"), + "key": key, + "name": ex.get("name") or key, + "monitored": monitored, + "data_ok": data_ok, + "funding_usdt": fu, + "trading_usdt": tu, + "total_usdt": total, + "series": series, + "drawdown": dd, + "day_delta_usdt": day_delta, + } + ) + if key: + monitored_keys.append(key) + + total_series = _series_from_history(history, monitored_keys) + if live_known > 0: + last_day = total_series[-1]["day"] if total_series else None + live_point = round(live_total, 4) + if last_day == day and total_series: + total_series[-1]["total_usdt"] = live_point + total_series[-1]["live"] = True + else: + total_series.append({"day": day, "total_usdt": live_point, "live": True}) + + total_dd = compute_drawdown([p["total_usdt"] for p in total_series]) if total_series else { + "peak_usdt": None, + "max_drawdown_u": None, + "max_drawdown_pct": None, + } + total_day_delta = None + if total_series: + if len(total_series) >= 2: + total_day_delta = round( + total_series[-1]["total_usdt"] - total_series[-2]["total_usdt"], 4 + ) + + return { + "ok": True, + "trading_day": day, + "reset_hour": reset_hour, + "keep_days": keep_days, + "history_start_day": fund_history_start_day(), + "updated_at": updated_at, + "totals": { + "monitored_count": len(monitored_keys), + "live_known_count": live_known, + "total_usdt": round(live_total, 4) if live_known > 0 else None, + "day_delta_usdt": total_day_delta, + "series": total_series, + "drawdown": total_dd, + }, + "accounts": accounts_out, + } + + +def format_fund_history_text( + history: dict[str, dict], + *, + account_names: Optional[dict[str, str]] = None, +) -> str: + if not history: + return "(暂无资金历史快照)" + names = account_names or {} + lines = ["【资金快照(资金账户 + 交易账户 USDT)】"] + for day in sorted(history.keys()): + block = history.get(day) or {} + ac_map = block.get("accounts") or {} + if not ac_map: + continue + parts = [] + for key, ac in ac_map.items(): + label = names.get(key) or ac.get("name") or key + fu = ac.get("funding_usdt") + tu = ac.get("trading_usdt") + tot = ac.get("total_usdt") + if tot is None: + tot = account_total_usdt(fu, tu) + fu_txt = f"{fu}U" if fu is not None else "未知" + tu_txt = f"{tu}U" if tu is not None else "未知" + tot_txt = f"{tot}U" if tot is not None else "未知" + parts.append(f"{label}: 合计{tot_txt}(资金{fu_txt}/交易{tu_txt})") + lines.append(f"- {day}: " + ";".join(parts)) + return "\n".join(lines) if len(lines) > 1 else "(暂无资金历史快照)" diff --git a/hub_host_status_lib.py b/lib/hub/hub_host_status_lib.py similarity index 100% rename from hub_host_status_lib.py rename to lib/hub/hub_host_status_lib.py diff --git a/hub_kline_store.py b/lib/hub/hub_kline_store.py similarity index 96% rename from hub_kline_store.py rename to lib/hub/hub_kline_store.py index 054a970..4dd736e 100644 --- a/hub_kline_store.py +++ b/lib/hub/hub_kline_store.py @@ -1,881 +1,881 @@ -"""中控 K 线 SQLite:分周期保留、交易所直拉、分页读取。""" - -from __future__ import annotations - -import os -import sqlite3 -import time -from pathlib import Path -from typing import Any, Callable, Optional - -from hub_ohlcv_lib import ( - HUB_KLINE_1M_MAX_BARS, - HUB_KLINE_5M_1H_RETENTION_DAYS, - TIMEFRAME_MS, - YEAR_ROLLING_STORED, - chart_chunk_limit, - chart_initial_limit, - chart_memory_cap, - history_cutoff_ms_for_storage, - normalize_chart_timeframe, - normalize_price_tick, - format_price_by_tick, - last_closed_bar_open_ms, - retention_policy_meta, - round_ohlcv_bars_to_tick, - seed_bar_target, -) - -HUB_KLINE_MIN_BARS_BEFORE_TAIL = 200 -HUB_KLINE_REMOTE_FETCH_CAP = 1500 - -_DEFAULT_RETENTION_DAYS = 15 - - -def retention_days() -> int: - """兼容旧配置;新策略见 retention_policy_meta。""" - try: - return max(1, int(os.getenv("HUB_KLINE_RETENTION_DAYS", str(_DEFAULT_RETENTION_DAYS)))) - except ValueError: - return _DEFAULT_RETENTION_DAYS - - -def default_db_path() -> Path: - raw = (os.getenv("HUB_KLINE_DB_PATH") or "").strip() - if raw: - return Path(raw) - hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data" - hub_dir.mkdir(parents=True, exist_ok=True) - return hub_dir / "hub_kline.db" - - -def _connect(db_path: Path | None = None) -> sqlite3.Connection: - path = db_path or default_db_path() - path.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(path), timeout=30, isolation_level=None) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - return conn - - -def init_db(db_path: Path | None = None) -> None: - conn = _connect(db_path) - try: - conn.execute( - """ - CREATE TABLE IF NOT EXISTS ohlcv_bars ( - exchange_key TEXT NOT NULL, - symbol TEXT NOT NULL, - timeframe TEXT NOT NULL, - open_time_ms INTEGER NOT NULL, - open REAL NOT NULL, - high REAL NOT NULL, - low REAL NOT NULL, - close REAL NOT NULL, - volume REAL NOT NULL DEFAULT 0, - updated_at INTEGER NOT NULL, - PRIMARY KEY (exchange_key, symbol, timeframe, open_time_ms) - ) - """ - ) - conn.execute( - """ - CREATE INDEX IF NOT EXISTS idx_ohlcv_series - ON ohlcv_bars (exchange_key, symbol, timeframe, open_time_ms) - """ - ) - conn.execute( - """ - CREATE TABLE IF NOT EXISTS ohlcv_symbol_meta ( - exchange_key TEXT NOT NULL, - symbol TEXT NOT NULL, - price_tick REAL, - updated_at INTEGER NOT NULL, - PRIMARY KEY (exchange_key, symbol) - ) - """ - ) - finally: - conn.close() - - -def save_symbol_price_tick( - exchange_key: str, - symbol: str, - price_tick: float | None, - db_path: Path | None = None, -) -> None: - tick = price_tick - if tick is None: - return - try: - t = float(tick) - except (TypeError, ValueError): - return - if t <= 0: - return - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - conn = _connect(db_path) - try: - conn.execute( - """ - INSERT INTO ohlcv_symbol_meta (exchange_key, symbol, price_tick, updated_at) - VALUES (?,?,?,?) - ON CONFLICT(exchange_key, symbol) DO UPDATE SET - price_tick=excluded.price_tick, - updated_at=excluded.updated_at - """, - (ex_k, sym, t, int(time.time())), - ) - finally: - conn.close() - - -def load_symbol_price_tick( - exchange_key: str, - symbol: str, - db_path: Path | None = None, -) -> float | None: - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - conn = _connect(db_path) - try: - row = conn.execute( - "SELECT price_tick FROM ohlcv_symbol_meta WHERE exchange_key=? AND symbol=?", - (ex_k, sym), - ).fetchone() - if not row or row["price_tick"] is None: - return None - return float(row["price_tick"]) - except (TypeError, ValueError): - return None - finally: - conn.close() - - -def purge_timeframe_by_days( - timeframe: str, - days: int, - db_path: Path | None = None, -) -> int: - cutoff = int(time.time() * 1000) - max(1, int(days)) * 86400000 - tf = normalize_chart_timeframe(timeframe) - conn = _connect(db_path) - try: - cur = conn.execute( - "DELETE FROM ohlcv_bars WHERE timeframe=? AND open_time_ms < ?", - (tf, cutoff), - ) - return int(cur.rowcount or 0) - finally: - conn.close() - - -def purge_1m_bar_cap(db_path: Path | None = None, *, max_bars: int | None = None) -> int: - cap = max(100, int(max_bars or HUB_KLINE_1M_MAX_BARS)) - conn = _connect(db_path) - try: - cur = conn.execute( - """ - DELETE FROM ohlcv_bars - WHERE timeframe='1m' AND rowid IN ( - SELECT rowid FROM ( - SELECT rowid, - ROW_NUMBER() OVER ( - PARTITION BY exchange_key, symbol - ORDER BY open_time_ms DESC - ) AS rn - FROM ohlcv_bars - WHERE timeframe='1m' - ) WHERE rn > ? - ) - """, - (cap,), - ) - return int(cur.rowcount or 0) - finally: - conn.close() - - -def clear_series_bars( - exchange_key: str, - symbol: str, - timeframe: str | None = None, - db_path: Path | None = None, -) -> int: - """删除某交易所+币种 K 线(可指定周期);用于清库后全量重拉。""" - init_db(db_path) - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - if not ex_k or not sym: - return 0 - conn = _connect(db_path) - try: - if timeframe: - tf = normalize_chart_timeframe(timeframe) - cur = conn.execute( - "DELETE FROM ohlcv_bars WHERE exchange_key=? AND symbol=? AND timeframe=?", - (ex_k, sym, tf), - ) - else: - cur = conn.execute( - "DELETE FROM ohlcv_bars WHERE exchange_key=? AND symbol=?", - (ex_k, sym), - ) - return int(cur.rowcount or 0) - finally: - conn.close() - - -def clear_all_bars(db_path: Path | None = None) -> int: - """清空 hub K 线库全部 OHLCV 行。""" - init_db(db_path) - conn = _connect(db_path) - try: - cur = conn.execute("DELETE FROM ohlcv_bars") - return int(cur.rowcount or 0) - finally: - conn.close() - - -def purge_retention(db_path: Path | None = None) -> int: - """按周期策略清理:5m/15m/1h/2h/4h 一年;1m 保留最近 N 根;1d/1w 不删。""" - n = 0 - for tf in sorted(YEAR_ROLLING_STORED): - n += purge_timeframe_by_days(tf, HUB_KLINE_5M_1H_RETENTION_DAYS, db_path) - n += purge_1m_bar_cap(db_path) - return n - - -def upsert_bars( - exchange_key: str, - symbol: str, - timeframe: str, - bars: list[dict[str, Any]], - db_path: Path | None = None, -) -> int: - if not bars: - return 0 - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - tf = normalize_chart_timeframe(timeframe) - now = int(time.time()) - conn = _connect(db_path) - n = 0 - try: - for b in bars: - try: - oms = int(b["open_time_ms"]) - conn.execute( - """ - INSERT INTO ohlcv_bars - (exchange_key, symbol, timeframe, open_time_ms, open, high, low, close, volume, updated_at) - VALUES (?,?,?,?,?,?,?,?,?,?) - ON CONFLICT(exchange_key, symbol, timeframe, open_time_ms) DO UPDATE SET - open=excluded.open, - high=excluded.high, - low=excluded.low, - close=excluded.close, - volume=excluded.volume, - updated_at=excluded.updated_at - """, - ( - ex_k, - sym, - tf, - oms, - float(b["open"]), - float(b["high"]), - float(b["low"]), - float(b["close"]), - float(b.get("volume") or 0), - now, - ), - ) - n += 1 - except (KeyError, TypeError, ValueError): - continue - finally: - conn.close() - return n - - -def load_bars_range( - exchange_key: str, - symbol: str, - timeframe: str, - start_ms: int, - end_ms: int, - db_path: Path | None = None, -) -> list[dict[str, Any]]: - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - tf = normalize_chart_timeframe(timeframe) - conn = _connect(db_path) - try: - rows = conn.execute( - """ - SELECT open_time_ms, open, high, low, close, volume - FROM ohlcv_bars - WHERE exchange_key=? AND symbol=? AND timeframe=? - AND open_time_ms >= ? AND open_time_ms <= ? - ORDER BY open_time_ms ASC - """, - (ex_k, sym, tf, int(start_ms), int(end_ms)), - ).fetchall() - return _rows_to_bars(rows) - finally: - conn.close() - - -def count_series_bars( - exchange_key: str, - symbol: str, - timeframe: str, - db_path: Path | None = None, -) -> int: - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - tf = normalize_chart_timeframe(timeframe) - conn = _connect(db_path) - try: - row = conn.execute( - """ - SELECT COUNT(*) AS c FROM ohlcv_bars - WHERE exchange_key=? AND symbol=? AND timeframe=? - """, - (ex_k, sym, tf), - ).fetchone() - return int(row["c"] or 0) if row else 0 - finally: - conn.close() - - -def _remote_fetch_limit( - *, - need: int, - force_refresh: bool, - storage_tf: str, - tail_only: bool, -) -> int: - if tail_only: - return min(need + 20, 300) - cap = HUB_KLINE_REMOTE_FETCH_CAP - if force_refresh: - return min(seed_bar_target(storage_tf), cap) - return min(max(need + 20, 1), cap) - - -def _since_ms_for_span( - *, - now_ms: int, - period_ms: int, - span_bars: int, - cutoff_ms: int, -) -> int: - """拉取窗口起点:跨度必须与 fetch_limit 一致,保证数据能铺到最近。""" - span = max(1, int(span_bars)) - return max(int(cutoff_ms), int(now_ms) - int(period_ms) * span) - - -def load_bars_latest( - exchange_key: str, - symbol: str, - timeframe: str, - limit: int, - db_path: Path | None = None, -) -> list[dict[str, Any]]: - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - tf = normalize_chart_timeframe(timeframe) - lim = max(1, int(limit)) - conn = _connect(db_path) - try: - rows = conn.execute( - """ - SELECT open_time_ms, open, high, low, close, volume - FROM ohlcv_bars - WHERE exchange_key=? AND symbol=? AND timeframe=? - ORDER BY open_time_ms DESC - LIMIT ? - """, - (ex_k, sym, tf, lim), - ).fetchall() - return list(reversed(_rows_to_bars(rows))) - finally: - conn.close() - - -def load_bars_before( - exchange_key: str, - symbol: str, - timeframe: str, - before_ms: int, - limit: int, - db_path: Path | None = None, -) -> list[dict[str, Any]]: - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - tf = normalize_chart_timeframe(timeframe) - lim = max(1, int(limit)) - bms = int(before_ms) - conn = _connect(db_path) - try: - rows = conn.execute( - """ - SELECT open_time_ms, open, high, low, close, volume - FROM ohlcv_bars - WHERE exchange_key=? AND symbol=? AND timeframe=? - AND open_time_ms < ? - ORDER BY open_time_ms DESC - LIMIT ? - """, - (ex_k, sym, tf, bms, lim), - ).fetchall() - return list(reversed(_rows_to_bars(rows))) - finally: - conn.close() - - -def trim_contiguous_tail( - bars: list[dict[str, Any]], - period_ms: int, - *, - max_gap_factor: float = 3.0, -) -> tuple[list[dict[str, Any]], int]: - """只保留最近一段连续 K 线,丢弃左侧与主段断开的孤立数据。""" - if len(bars) <= 1: - return list(bars), 0 - try: - period = max(1, int(period_ms)) - except (TypeError, ValueError): - period = 60_000 - max_gap = int(period * max_gap_factor) - split = 0 - for i in range(len(bars) - 1, 0, -1): - gap = int(bars[i]["open_time_ms"]) - int(bars[i - 1]["open_time_ms"]) - if gap > max_gap: - split = i - break - return bars[split:], split - - -def normalize_contiguous_db_rows( - bars: list[dict[str, Any]], - *, - period_ms: int, - exchange_key: str, - symbol: str, - timeframe: str, - db_path: Path | None = None, - purge_orphans: bool = True, -) -> list[dict[str, Any]]: - """去掉与主段断开的孤立前缀;可选同步清理库内孤立数据。""" - if len(bars) <= 1: - return list(bars) - trimmed, split_at = trim_contiguous_tail(bars, period_ms) - if split_at > 0 and purge_orphans: - purge_bars_open_before( - exchange_key, - symbol, - timeframe, - int(trimmed[0]["open_time_ms"]), - db_path, - ) - return trimmed - - -def purge_bars_open_before( - exchange_key: str, - symbol: str, - timeframe: str, - open_time_ms: int, - db_path: Path | None = None, -) -> int: - """删除某品种周期下早于 open_time_ms 的 K 线(清理与主段断开的孤立历史)。""" - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - tf = normalize_chart_timeframe(timeframe) - conn = _connect(db_path) - try: - cur = conn.execute( - """ - DELETE FROM ohlcv_bars - WHERE exchange_key=? AND symbol=? AND timeframe=? AND open_time_ms < ? - """, - (ex_k, sym, tf, int(open_time_ms)), - ) - return int(cur.rowcount or 0) - finally: - conn.close() - - -def _rows_to_bars(rows) -> list[dict[str, Any]]: - return [ - { - "open_time_ms": int(r["open_time_ms"]), - "open": float(r["open"]), - "high": float(r["high"]), - "low": float(r["low"]), - "close": float(r["close"]), - "volume": float(r["volume"] or 0), - } - for r in rows - ] - - -def _to_chart_candles(bars: list[dict[str, Any]]) -> list[dict[str, Any]]: - out = [] - for b in bars: - try: - out.append( - { - "time": int(b["open_time_ms"] // 1000), - "open": float(b["open"]), - "high": float(b["high"]), - "low": float(b["low"]), - "close": float(b["close"]), - "volume": float(b.get("volume") or 0), - } - ) - except (KeyError, TypeError, ValueError): - continue - return out - - -def _trim_display_bars( - bars: list[dict[str, Any]], - *, - need: int, - before_ms: int | None, -) -> list[dict[str, Any]]: - if not bars: - return [] - if before_ms is not None and int(before_ms) > 0: - bms = int(before_ms) - bars = [b for b in bars if int(b["open_time_ms"]) < bms] - if len(bars) > need: - bars = bars[-need:] - return bars - if len(bars) > need: - bars = bars[-need:] - return bars - - -def resolve_chart_bars( - exchange_key: str, - symbol: str, - timeframe: str, - remote_fetch: Callable[..., dict[str, Any]], - *, - db_path: Path | None = None, - force_refresh: bool = False, - tail_refresh: bool = False, - clear_db: bool = False, - limit: int | None = None, - before_ms: int | None = None, -) -> dict[str, Any]: - """ - 分页读库:首屏 / 左拖 before_ms / 尾部 tail_refresh。 - 各展示周期均直读交易所同步入库的同名 K 线。 - """ - init_db(db_path) - purged = purge_retention(db_path) - cleared = 0 - - sym = (symbol or "").strip().upper() - ex_k = (exchange_key or "").strip().lower() - display_tf = normalize_chart_timeframe(timeframe) - if not sym or not ex_k: - return {"ok": False, "msg": "缺少 exchange 或 symbol"} - - storage_tf = display_tf - is_history = before_ms is not None and int(before_ms) > 0 - need = int( - limit - or (chart_chunk_limit(display_tf) if is_history else chart_initial_limit(display_tf)) - ) - need = max(1, min(need, chart_memory_cap(display_tf))) - - now_ms = int(time.time() * 1000) - period_display = TIMEFRAME_MS[display_tf] - period_storage = TIMEFRAME_MS[storage_tf] - series_bar_count = ( - count_series_bars(ex_k, sym, storage_tf, db_path) if not is_history else 0 - ) - if tail_refresh and not is_history: - min_seed = min(chart_initial_limit(display_tf) // 5, HUB_KLINE_MIN_BARS_BEFORE_TAIL) - if series_bar_count < max(1, min_seed): - tail_refresh = False - else: - need = min(need, 30) - cutoff = history_cutoff_ms_for_storage(storage_tf, now_ms) - - if clear_db and not is_history and not tail_refresh: - cleared = clear_series_bars(ex_k, sym, storage_tf, db_path) - - def load_display_rows() -> list[dict[str, Any]]: - if is_history: - rows = load_bars_before(ex_k, sym, storage_tf, int(before_ms), need, db_path) - return _trim_display_bars(rows, need=need, before_ms=int(before_ms)) - return load_bars_latest(ex_k, sym, storage_tf, need, db_path) - - db_rows: list[dict[str, Any]] = [] - if not force_refresh: - db_rows = load_display_rows() - if not is_history and db_rows: - db_rows = normalize_contiguous_db_rows( - db_rows, - period_ms=period_display, - exchange_key=ex_k, - symbol=sym, - timeframe=storage_tf, - db_path=db_path, - ) - - last_closed = last_closed_bar_open_ms(display_tf, now_ms) - newest_db = db_rows[-1]["open_time_ms"] if db_rows else None - if is_history: - newest_ok = True - else: - newest_ok = newest_db is not None and int(newest_db) >= int(last_closed) - period_display - - need_fetch = force_refresh or ( - not is_history and (len(db_rows) < need or not newest_ok) - ) - if is_history and len(db_rows) < need: - need_fetch = True - - tail_only = False - if tail_refresh and not is_history and db_rows and not force_refresh and not need_fetch: - need_fetch = True - tail_only = True - - fetched = 0 - price_tick: Optional[float] = None - remote_err: Optional[str] = None - - if need_fetch: - if is_history: - bms = int(before_ms) - anchor = bms - period_display - since = max(cutoff, anchor - period_storage * need) - fetch_limit = min(need + 20, 1500) - elif tail_only: - anchor_ms = int(newest_db) if newest_db is not None else now_ms - fetch_limit = _remote_fetch_limit( - need=need, force_refresh=False, storage_tf=storage_tf, tail_only=True - ) - since = _since_ms_for_span( - now_ms=anchor_ms, - period_ms=period_storage, - span_bars=5, - cutoff_ms=cutoff, - ) - else: - fetch_limit = _remote_fetch_limit( - need=need, - force_refresh=force_refresh, - storage_tf=storage_tf, - tail_only=False, - ) - since = _since_ms_for_span( - now_ms=now_ms, - period_ms=period_storage, - span_bars=fetch_limit, - cutoff_ms=cutoff, - ) - - remote = remote_fetch( - symbol=sym, - timeframe=storage_tf, - since_ms=since, - limit=fetch_limit, - ) - if remote.get("ok") and remote.get("bars"): - fetched = upsert_bars(ex_k, sym, storage_tf, remote["bars"], db_path) - price_tick = remote.get("price_tick") - if price_tick is not None: - save_symbol_price_tick(ex_k, sym, price_tick, db_path) - db_rows = load_display_rows() - if not is_history and db_rows: - db_rows = normalize_contiguous_db_rows( - db_rows, - period_ms=period_display, - exchange_key=ex_k, - symbol=sym, - timeframe=storage_tf, - db_path=db_path, - ) - if not is_history and not tail_only and db_rows: - newest_ms = int(db_rows[-1]["open_time_ms"]) - if newest_ms < int(last_closed) - period_display: - gap_limit = min( - 500, - int((now_ms - newest_ms) // period_storage) + 10, - ) - if gap_limit > 1: - gap_remote = remote_fetch( - symbol=sym, - timeframe=storage_tf, - since_ms=newest_ms, - limit=gap_limit, - ) - if gap_remote.get("ok") and gap_remote.get("bars"): - fetched += upsert_bars( - ex_k, sym, storage_tf, gap_remote["bars"], db_path - ) - db_rows = load_display_rows() - db_rows = normalize_contiguous_db_rows( - db_rows, - period_ms=period_display, - exchange_key=ex_k, - symbol=sym, - timeframe=storage_tf, - db_path=db_path, - ) - else: - remote_err = remote.get("msg") or remote.get("error") or "实例拉取 K 线失败" - if not db_rows: - if is_history: - exhausted = True - else: - return {"ok": False, "msg": remote_err, "purged": purged} - - exhausted = False - if is_history: - if not db_rows: - exhausted = True - elif len(db_rows) < need: - oldest = int(db_rows[0]["open_time_ms"]) - if cutoff > 0 and oldest <= cutoff + period_storage: - exhausted = True - elif fetched == 0: - exhausted = True - - if price_tick is None: - price_tick = load_symbol_price_tick(ex_k, sym, db_path) - if price_tick is None and not is_history: - try: - tick_probe = remote_fetch( - symbol=sym, - timeframe=storage_tf, - since_ms=None, - limit=3, - ) - if tick_probe.get("ok"): - price_tick = tick_probe.get("price_tick") - if price_tick is not None: - save_symbol_price_tick(ex_k, sym, price_tick, db_path) - except Exception: - pass - - if not is_history and db_rows: - db_rows = normalize_contiguous_db_rows( - db_rows, - period_ms=period_display, - exchange_key=ex_k, - symbol=sym, - timeframe=storage_tf, - db_path=db_path, - ) - - if not is_history and len(db_rows) < need: - missing = need - len(db_rows) - backfill_limit = min(missing + 60, HUB_KLINE_REMOTE_FETCH_CAP) - if db_rows: - oldest = int(db_rows[0]["open_time_ms"]) - backfill_since = _since_ms_for_span( - now_ms=oldest, - period_ms=period_storage, - span_bars=backfill_limit, - cutoff_ms=cutoff, - ) - else: - backfill_since = _since_ms_for_span( - now_ms=now_ms, - period_ms=period_storage, - span_bars=backfill_limit, - cutoff_ms=cutoff, - ) - try: - remote_back = remote_fetch( - symbol=sym, - timeframe=storage_tf, - since_ms=backfill_since, - limit=backfill_limit, - ) - if remote_back.get("ok") and remote_back.get("bars"): - fetched += upsert_bars(ex_k, sym, storage_tf, remote_back["bars"], db_path) - if remote_back.get("price_tick") is not None: - price_tick = remote_back.get("price_tick") - save_symbol_price_tick(ex_k, sym, price_tick, db_path) - db_rows = load_display_rows() - db_rows = normalize_contiguous_db_rows( - db_rows, - period_ms=period_display, - exchange_key=ex_k, - symbol=sym, - timeframe=storage_tf, - db_path=db_path, - ) - elif not remote_err: - remote_err = ( - remote_back.get("msg") - or remote_back.get("error") - or "实例补拉 K 线失败" - ) - except Exception as e: - if not remote_err: - remote_err = str(e) - - price_tick = normalize_price_tick(price_tick) - if db_rows and price_tick is not None: - round_ohlcv_bars_to_tick(db_rows, price_tick) - - candles = _to_chart_candles(db_rows) - if not is_history and not candles and not exhausted: - return {"ok": False, "msg": remote_err or "无 K 线数据", "purged": purged} - - oldest_ms = int(db_rows[0]["open_time_ms"]) if db_rows else None - newest_ms = int(db_rows[-1]["open_time_ms"]) if db_rows else None - - from_cache = max(0, len(candles) - min(fetched, len(candles))) if fetched else len(candles) - - return { - "ok": True, - "symbol": sym, - "exchange_key": ex_k, - "timeframe": display_tf, - "storage_timeframe": storage_tf, - "limit": need, - "before_ms": int(before_ms) if is_history else None, - "oldest_ms": oldest_ms, - "newest_ms": newest_ms, - "exhausted": exhausted, - "source": "remote" if fetched else "db", - "retention_policy": retention_policy_meta(), - "candles": candles, - "from_cache": from_cache, - "fetched": fetched, - "cleared": cleared, - "purged": purged, - "price_tick": price_tick, - "stale": bool(remote_err), - "stale_message": remote_err if remote_err else None, - "updated_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), - } - - -def format_ohlcv_detail(bar: dict[str, Any] | None, tick: Optional[float]) -> dict[str, str]: - if not bar: - return {"open": "-", "high": "-", "low": "-", "close": "-", "volume": "-"} - return { - "open": format_price_by_tick(bar.get("open"), tick), - "high": format_price_by_tick(bar.get("high"), tick), - "low": format_price_by_tick(bar.get("low"), tick), - "close": format_price_by_tick(bar.get("close"), tick), - "volume": format_price_by_tick(bar.get("volume"), tick), - } +"""中控 K 线 SQLite:分周期保留、交易所直拉、分页读取。""" + +from __future__ import annotations + +import os +import sqlite3 +import time +from pathlib import Path +from typing import Any, Callable, Optional + +from lib.hub.hub_ohlcv_lib import ( + HUB_KLINE_1M_MAX_BARS, + HUB_KLINE_5M_1H_RETENTION_DAYS, + TIMEFRAME_MS, + YEAR_ROLLING_STORED, + chart_chunk_limit, + chart_initial_limit, + chart_memory_cap, + history_cutoff_ms_for_storage, + normalize_chart_timeframe, + normalize_price_tick, + format_price_by_tick, + last_closed_bar_open_ms, + retention_policy_meta, + round_ohlcv_bars_to_tick, + seed_bar_target, +) + +HUB_KLINE_MIN_BARS_BEFORE_TAIL = 200 +HUB_KLINE_REMOTE_FETCH_CAP = 1500 + +_DEFAULT_RETENTION_DAYS = 15 + + +def retention_days() -> int: + """兼容旧配置;新策略见 retention_policy_meta。""" + try: + return max(1, int(os.getenv("HUB_KLINE_RETENTION_DAYS", str(_DEFAULT_RETENTION_DAYS)))) + except ValueError: + return _DEFAULT_RETENTION_DAYS + + +def default_db_path() -> Path: + raw = (os.getenv("HUB_KLINE_DB_PATH") or "").strip() + if raw: + return Path(raw) + hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data" + hub_dir.mkdir(parents=True, exist_ok=True) + return hub_dir / "hub_kline.db" + + +def _connect(db_path: Path | None = None) -> sqlite3.Connection: + path = db_path or default_db_path() + path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(path), timeout=30, isolation_level=None) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + return conn + + +def init_db(db_path: Path | None = None) -> None: + conn = _connect(db_path) + try: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS ohlcv_bars ( + exchange_key TEXT NOT NULL, + symbol TEXT NOT NULL, + timeframe TEXT NOT NULL, + open_time_ms INTEGER NOT NULL, + open REAL NOT NULL, + high REAL NOT NULL, + low REAL NOT NULL, + close REAL NOT NULL, + volume REAL NOT NULL DEFAULT 0, + updated_at INTEGER NOT NULL, + PRIMARY KEY (exchange_key, symbol, timeframe, open_time_ms) + ) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_ohlcv_series + ON ohlcv_bars (exchange_key, symbol, timeframe, open_time_ms) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS ohlcv_symbol_meta ( + exchange_key TEXT NOT NULL, + symbol TEXT NOT NULL, + price_tick REAL, + updated_at INTEGER NOT NULL, + PRIMARY KEY (exchange_key, symbol) + ) + """ + ) + finally: + conn.close() + + +def save_symbol_price_tick( + exchange_key: str, + symbol: str, + price_tick: float | None, + db_path: Path | None = None, +) -> None: + tick = price_tick + if tick is None: + return + try: + t = float(tick) + except (TypeError, ValueError): + return + if t <= 0: + return + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + conn = _connect(db_path) + try: + conn.execute( + """ + INSERT INTO ohlcv_symbol_meta (exchange_key, symbol, price_tick, updated_at) + VALUES (?,?,?,?) + ON CONFLICT(exchange_key, symbol) DO UPDATE SET + price_tick=excluded.price_tick, + updated_at=excluded.updated_at + """, + (ex_k, sym, t, int(time.time())), + ) + finally: + conn.close() + + +def load_symbol_price_tick( + exchange_key: str, + symbol: str, + db_path: Path | None = None, +) -> float | None: + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + conn = _connect(db_path) + try: + row = conn.execute( + "SELECT price_tick FROM ohlcv_symbol_meta WHERE exchange_key=? AND symbol=?", + (ex_k, sym), + ).fetchone() + if not row or row["price_tick"] is None: + return None + return float(row["price_tick"]) + except (TypeError, ValueError): + return None + finally: + conn.close() + + +def purge_timeframe_by_days( + timeframe: str, + days: int, + db_path: Path | None = None, +) -> int: + cutoff = int(time.time() * 1000) - max(1, int(days)) * 86400000 + tf = normalize_chart_timeframe(timeframe) + conn = _connect(db_path) + try: + cur = conn.execute( + "DELETE FROM ohlcv_bars WHERE timeframe=? AND open_time_ms < ?", + (tf, cutoff), + ) + return int(cur.rowcount or 0) + finally: + conn.close() + + +def purge_1m_bar_cap(db_path: Path | None = None, *, max_bars: int | None = None) -> int: + cap = max(100, int(max_bars or HUB_KLINE_1M_MAX_BARS)) + conn = _connect(db_path) + try: + cur = conn.execute( + """ + DELETE FROM ohlcv_bars + WHERE timeframe='1m' AND rowid IN ( + SELECT rowid FROM ( + SELECT rowid, + ROW_NUMBER() OVER ( + PARTITION BY exchange_key, symbol + ORDER BY open_time_ms DESC + ) AS rn + FROM ohlcv_bars + WHERE timeframe='1m' + ) WHERE rn > ? + ) + """, + (cap,), + ) + return int(cur.rowcount or 0) + finally: + conn.close() + + +def clear_series_bars( + exchange_key: str, + symbol: str, + timeframe: str | None = None, + db_path: Path | None = None, +) -> int: + """删除某交易所+币种 K 线(可指定周期);用于清库后全量重拉。""" + init_db(db_path) + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + if not ex_k or not sym: + return 0 + conn = _connect(db_path) + try: + if timeframe: + tf = normalize_chart_timeframe(timeframe) + cur = conn.execute( + "DELETE FROM ohlcv_bars WHERE exchange_key=? AND symbol=? AND timeframe=?", + (ex_k, sym, tf), + ) + else: + cur = conn.execute( + "DELETE FROM ohlcv_bars WHERE exchange_key=? AND symbol=?", + (ex_k, sym), + ) + return int(cur.rowcount or 0) + finally: + conn.close() + + +def clear_all_bars(db_path: Path | None = None) -> int: + """清空 hub K 线库全部 OHLCV 行。""" + init_db(db_path) + conn = _connect(db_path) + try: + cur = conn.execute("DELETE FROM ohlcv_bars") + return int(cur.rowcount or 0) + finally: + conn.close() + + +def purge_retention(db_path: Path | None = None) -> int: + """按周期策略清理:5m/15m/1h/2h/4h 一年;1m 保留最近 N 根;1d/1w 不删。""" + n = 0 + for tf in sorted(YEAR_ROLLING_STORED): + n += purge_timeframe_by_days(tf, HUB_KLINE_5M_1H_RETENTION_DAYS, db_path) + n += purge_1m_bar_cap(db_path) + return n + + +def upsert_bars( + exchange_key: str, + symbol: str, + timeframe: str, + bars: list[dict[str, Any]], + db_path: Path | None = None, +) -> int: + if not bars: + return 0 + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + tf = normalize_chart_timeframe(timeframe) + now = int(time.time()) + conn = _connect(db_path) + n = 0 + try: + for b in bars: + try: + oms = int(b["open_time_ms"]) + conn.execute( + """ + INSERT INTO ohlcv_bars + (exchange_key, symbol, timeframe, open_time_ms, open, high, low, close, volume, updated_at) + VALUES (?,?,?,?,?,?,?,?,?,?) + ON CONFLICT(exchange_key, symbol, timeframe, open_time_ms) DO UPDATE SET + open=excluded.open, + high=excluded.high, + low=excluded.low, + close=excluded.close, + volume=excluded.volume, + updated_at=excluded.updated_at + """, + ( + ex_k, + sym, + tf, + oms, + float(b["open"]), + float(b["high"]), + float(b["low"]), + float(b["close"]), + float(b.get("volume") or 0), + now, + ), + ) + n += 1 + except (KeyError, TypeError, ValueError): + continue + finally: + conn.close() + return n + + +def load_bars_range( + exchange_key: str, + symbol: str, + timeframe: str, + start_ms: int, + end_ms: int, + db_path: Path | None = None, +) -> list[dict[str, Any]]: + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + tf = normalize_chart_timeframe(timeframe) + conn = _connect(db_path) + try: + rows = conn.execute( + """ + SELECT open_time_ms, open, high, low, close, volume + FROM ohlcv_bars + WHERE exchange_key=? AND symbol=? AND timeframe=? + AND open_time_ms >= ? AND open_time_ms <= ? + ORDER BY open_time_ms ASC + """, + (ex_k, sym, tf, int(start_ms), int(end_ms)), + ).fetchall() + return _rows_to_bars(rows) + finally: + conn.close() + + +def count_series_bars( + exchange_key: str, + symbol: str, + timeframe: str, + db_path: Path | None = None, +) -> int: + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + tf = normalize_chart_timeframe(timeframe) + conn = _connect(db_path) + try: + row = conn.execute( + """ + SELECT COUNT(*) AS c FROM ohlcv_bars + WHERE exchange_key=? AND symbol=? AND timeframe=? + """, + (ex_k, sym, tf), + ).fetchone() + return int(row["c"] or 0) if row else 0 + finally: + conn.close() + + +def _remote_fetch_limit( + *, + need: int, + force_refresh: bool, + storage_tf: str, + tail_only: bool, +) -> int: + if tail_only: + return min(need + 20, 300) + cap = HUB_KLINE_REMOTE_FETCH_CAP + if force_refresh: + return min(seed_bar_target(storage_tf), cap) + return min(max(need + 20, 1), cap) + + +def _since_ms_for_span( + *, + now_ms: int, + period_ms: int, + span_bars: int, + cutoff_ms: int, +) -> int: + """拉取窗口起点:跨度必须与 fetch_limit 一致,保证数据能铺到最近。""" + span = max(1, int(span_bars)) + return max(int(cutoff_ms), int(now_ms) - int(period_ms) * span) + + +def load_bars_latest( + exchange_key: str, + symbol: str, + timeframe: str, + limit: int, + db_path: Path | None = None, +) -> list[dict[str, Any]]: + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + tf = normalize_chart_timeframe(timeframe) + lim = max(1, int(limit)) + conn = _connect(db_path) + try: + rows = conn.execute( + """ + SELECT open_time_ms, open, high, low, close, volume + FROM ohlcv_bars + WHERE exchange_key=? AND symbol=? AND timeframe=? + ORDER BY open_time_ms DESC + LIMIT ? + """, + (ex_k, sym, tf, lim), + ).fetchall() + return list(reversed(_rows_to_bars(rows))) + finally: + conn.close() + + +def load_bars_before( + exchange_key: str, + symbol: str, + timeframe: str, + before_ms: int, + limit: int, + db_path: Path | None = None, +) -> list[dict[str, Any]]: + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + tf = normalize_chart_timeframe(timeframe) + lim = max(1, int(limit)) + bms = int(before_ms) + conn = _connect(db_path) + try: + rows = conn.execute( + """ + SELECT open_time_ms, open, high, low, close, volume + FROM ohlcv_bars + WHERE exchange_key=? AND symbol=? AND timeframe=? + AND open_time_ms < ? + ORDER BY open_time_ms DESC + LIMIT ? + """, + (ex_k, sym, tf, bms, lim), + ).fetchall() + return list(reversed(_rows_to_bars(rows))) + finally: + conn.close() + + +def trim_contiguous_tail( + bars: list[dict[str, Any]], + period_ms: int, + *, + max_gap_factor: float = 3.0, +) -> tuple[list[dict[str, Any]], int]: + """只保留最近一段连续 K 线,丢弃左侧与主段断开的孤立数据。""" + if len(bars) <= 1: + return list(bars), 0 + try: + period = max(1, int(period_ms)) + except (TypeError, ValueError): + period = 60_000 + max_gap = int(period * max_gap_factor) + split = 0 + for i in range(len(bars) - 1, 0, -1): + gap = int(bars[i]["open_time_ms"]) - int(bars[i - 1]["open_time_ms"]) + if gap > max_gap: + split = i + break + return bars[split:], split + + +def normalize_contiguous_db_rows( + bars: list[dict[str, Any]], + *, + period_ms: int, + exchange_key: str, + symbol: str, + timeframe: str, + db_path: Path | None = None, + purge_orphans: bool = True, +) -> list[dict[str, Any]]: + """去掉与主段断开的孤立前缀;可选同步清理库内孤立数据。""" + if len(bars) <= 1: + return list(bars) + trimmed, split_at = trim_contiguous_tail(bars, period_ms) + if split_at > 0 and purge_orphans: + purge_bars_open_before( + exchange_key, + symbol, + timeframe, + int(trimmed[0]["open_time_ms"]), + db_path, + ) + return trimmed + + +def purge_bars_open_before( + exchange_key: str, + symbol: str, + timeframe: str, + open_time_ms: int, + db_path: Path | None = None, +) -> int: + """删除某品种周期下早于 open_time_ms 的 K 线(清理与主段断开的孤立历史)。""" + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + tf = normalize_chart_timeframe(timeframe) + conn = _connect(db_path) + try: + cur = conn.execute( + """ + DELETE FROM ohlcv_bars + WHERE exchange_key=? AND symbol=? AND timeframe=? AND open_time_ms < ? + """, + (ex_k, sym, tf, int(open_time_ms)), + ) + return int(cur.rowcount or 0) + finally: + conn.close() + + +def _rows_to_bars(rows) -> list[dict[str, Any]]: + return [ + { + "open_time_ms": int(r["open_time_ms"]), + "open": float(r["open"]), + "high": float(r["high"]), + "low": float(r["low"]), + "close": float(r["close"]), + "volume": float(r["volume"] or 0), + } + for r in rows + ] + + +def _to_chart_candles(bars: list[dict[str, Any]]) -> list[dict[str, Any]]: + out = [] + for b in bars: + try: + out.append( + { + "time": int(b["open_time_ms"] // 1000), + "open": float(b["open"]), + "high": float(b["high"]), + "low": float(b["low"]), + "close": float(b["close"]), + "volume": float(b.get("volume") or 0), + } + ) + except (KeyError, TypeError, ValueError): + continue + return out + + +def _trim_display_bars( + bars: list[dict[str, Any]], + *, + need: int, + before_ms: int | None, +) -> list[dict[str, Any]]: + if not bars: + return [] + if before_ms is not None and int(before_ms) > 0: + bms = int(before_ms) + bars = [b for b in bars if int(b["open_time_ms"]) < bms] + if len(bars) > need: + bars = bars[-need:] + return bars + if len(bars) > need: + bars = bars[-need:] + return bars + + +def resolve_chart_bars( + exchange_key: str, + symbol: str, + timeframe: str, + remote_fetch: Callable[..., dict[str, Any]], + *, + db_path: Path | None = None, + force_refresh: bool = False, + tail_refresh: bool = False, + clear_db: bool = False, + limit: int | None = None, + before_ms: int | None = None, +) -> dict[str, Any]: + """ + 分页读库:首屏 / 左拖 before_ms / 尾部 tail_refresh。 + 各展示周期均直读交易所同步入库的同名 K 线。 + """ + init_db(db_path) + purged = purge_retention(db_path) + cleared = 0 + + sym = (symbol or "").strip().upper() + ex_k = (exchange_key or "").strip().lower() + display_tf = normalize_chart_timeframe(timeframe) + if not sym or not ex_k: + return {"ok": False, "msg": "缺少 exchange 或 symbol"} + + storage_tf = display_tf + is_history = before_ms is not None and int(before_ms) > 0 + need = int( + limit + or (chart_chunk_limit(display_tf) if is_history else chart_initial_limit(display_tf)) + ) + need = max(1, min(need, chart_memory_cap(display_tf))) + + now_ms = int(time.time() * 1000) + period_display = TIMEFRAME_MS[display_tf] + period_storage = TIMEFRAME_MS[storage_tf] + series_bar_count = ( + count_series_bars(ex_k, sym, storage_tf, db_path) if not is_history else 0 + ) + if tail_refresh and not is_history: + min_seed = min(chart_initial_limit(display_tf) // 5, HUB_KLINE_MIN_BARS_BEFORE_TAIL) + if series_bar_count < max(1, min_seed): + tail_refresh = False + else: + need = min(need, 30) + cutoff = history_cutoff_ms_for_storage(storage_tf, now_ms) + + if clear_db and not is_history and not tail_refresh: + cleared = clear_series_bars(ex_k, sym, storage_tf, db_path) + + def load_display_rows() -> list[dict[str, Any]]: + if is_history: + rows = load_bars_before(ex_k, sym, storage_tf, int(before_ms), need, db_path) + return _trim_display_bars(rows, need=need, before_ms=int(before_ms)) + return load_bars_latest(ex_k, sym, storage_tf, need, db_path) + + db_rows: list[dict[str, Any]] = [] + if not force_refresh: + db_rows = load_display_rows() + if not is_history and db_rows: + db_rows = normalize_contiguous_db_rows( + db_rows, + period_ms=period_display, + exchange_key=ex_k, + symbol=sym, + timeframe=storage_tf, + db_path=db_path, + ) + + last_closed = last_closed_bar_open_ms(display_tf, now_ms) + newest_db = db_rows[-1]["open_time_ms"] if db_rows else None + if is_history: + newest_ok = True + else: + newest_ok = newest_db is not None and int(newest_db) >= int(last_closed) - period_display + + need_fetch = force_refresh or ( + not is_history and (len(db_rows) < need or not newest_ok) + ) + if is_history and len(db_rows) < need: + need_fetch = True + + tail_only = False + if tail_refresh and not is_history and db_rows and not force_refresh and not need_fetch: + need_fetch = True + tail_only = True + + fetched = 0 + price_tick: Optional[float] = None + remote_err: Optional[str] = None + + if need_fetch: + if is_history: + bms = int(before_ms) + anchor = bms - period_display + since = max(cutoff, anchor - period_storage * need) + fetch_limit = min(need + 20, 1500) + elif tail_only: + anchor_ms = int(newest_db) if newest_db is not None else now_ms + fetch_limit = _remote_fetch_limit( + need=need, force_refresh=False, storage_tf=storage_tf, tail_only=True + ) + since = _since_ms_for_span( + now_ms=anchor_ms, + period_ms=period_storage, + span_bars=5, + cutoff_ms=cutoff, + ) + else: + fetch_limit = _remote_fetch_limit( + need=need, + force_refresh=force_refresh, + storage_tf=storage_tf, + tail_only=False, + ) + since = _since_ms_for_span( + now_ms=now_ms, + period_ms=period_storage, + span_bars=fetch_limit, + cutoff_ms=cutoff, + ) + + remote = remote_fetch( + symbol=sym, + timeframe=storage_tf, + since_ms=since, + limit=fetch_limit, + ) + if remote.get("ok") and remote.get("bars"): + fetched = upsert_bars(ex_k, sym, storage_tf, remote["bars"], db_path) + price_tick = remote.get("price_tick") + if price_tick is not None: + save_symbol_price_tick(ex_k, sym, price_tick, db_path) + db_rows = load_display_rows() + if not is_history and db_rows: + db_rows = normalize_contiguous_db_rows( + db_rows, + period_ms=period_display, + exchange_key=ex_k, + symbol=sym, + timeframe=storage_tf, + db_path=db_path, + ) + if not is_history and not tail_only and db_rows: + newest_ms = int(db_rows[-1]["open_time_ms"]) + if newest_ms < int(last_closed) - period_display: + gap_limit = min( + 500, + int((now_ms - newest_ms) // period_storage) + 10, + ) + if gap_limit > 1: + gap_remote = remote_fetch( + symbol=sym, + timeframe=storage_tf, + since_ms=newest_ms, + limit=gap_limit, + ) + if gap_remote.get("ok") and gap_remote.get("bars"): + fetched += upsert_bars( + ex_k, sym, storage_tf, gap_remote["bars"], db_path + ) + db_rows = load_display_rows() + db_rows = normalize_contiguous_db_rows( + db_rows, + period_ms=period_display, + exchange_key=ex_k, + symbol=sym, + timeframe=storage_tf, + db_path=db_path, + ) + else: + remote_err = remote.get("msg") or remote.get("error") or "实例拉取 K 线失败" + if not db_rows: + if is_history: + exhausted = True + else: + return {"ok": False, "msg": remote_err, "purged": purged} + + exhausted = False + if is_history: + if not db_rows: + exhausted = True + elif len(db_rows) < need: + oldest = int(db_rows[0]["open_time_ms"]) + if cutoff > 0 and oldest <= cutoff + period_storage: + exhausted = True + elif fetched == 0: + exhausted = True + + if price_tick is None: + price_tick = load_symbol_price_tick(ex_k, sym, db_path) + if price_tick is None and not is_history: + try: + tick_probe = remote_fetch( + symbol=sym, + timeframe=storage_tf, + since_ms=None, + limit=3, + ) + if tick_probe.get("ok"): + price_tick = tick_probe.get("price_tick") + if price_tick is not None: + save_symbol_price_tick(ex_k, sym, price_tick, db_path) + except Exception: + pass + + if not is_history and db_rows: + db_rows = normalize_contiguous_db_rows( + db_rows, + period_ms=period_display, + exchange_key=ex_k, + symbol=sym, + timeframe=storage_tf, + db_path=db_path, + ) + + if not is_history and len(db_rows) < need: + missing = need - len(db_rows) + backfill_limit = min(missing + 60, HUB_KLINE_REMOTE_FETCH_CAP) + if db_rows: + oldest = int(db_rows[0]["open_time_ms"]) + backfill_since = _since_ms_for_span( + now_ms=oldest, + period_ms=period_storage, + span_bars=backfill_limit, + cutoff_ms=cutoff, + ) + else: + backfill_since = _since_ms_for_span( + now_ms=now_ms, + period_ms=period_storage, + span_bars=backfill_limit, + cutoff_ms=cutoff, + ) + try: + remote_back = remote_fetch( + symbol=sym, + timeframe=storage_tf, + since_ms=backfill_since, + limit=backfill_limit, + ) + if remote_back.get("ok") and remote_back.get("bars"): + fetched += upsert_bars(ex_k, sym, storage_tf, remote_back["bars"], db_path) + if remote_back.get("price_tick") is not None: + price_tick = remote_back.get("price_tick") + save_symbol_price_tick(ex_k, sym, price_tick, db_path) + db_rows = load_display_rows() + db_rows = normalize_contiguous_db_rows( + db_rows, + period_ms=period_display, + exchange_key=ex_k, + symbol=sym, + timeframe=storage_tf, + db_path=db_path, + ) + elif not remote_err: + remote_err = ( + remote_back.get("msg") + or remote_back.get("error") + or "实例补拉 K 线失败" + ) + except Exception as e: + if not remote_err: + remote_err = str(e) + + price_tick = normalize_price_tick(price_tick) + if db_rows and price_tick is not None: + round_ohlcv_bars_to_tick(db_rows, price_tick) + + candles = _to_chart_candles(db_rows) + if not is_history and not candles and not exhausted: + return {"ok": False, "msg": remote_err or "无 K 线数据", "purged": purged} + + oldest_ms = int(db_rows[0]["open_time_ms"]) if db_rows else None + newest_ms = int(db_rows[-1]["open_time_ms"]) if db_rows else None + + from_cache = max(0, len(candles) - min(fetched, len(candles))) if fetched else len(candles) + + return { + "ok": True, + "symbol": sym, + "exchange_key": ex_k, + "timeframe": display_tf, + "storage_timeframe": storage_tf, + "limit": need, + "before_ms": int(before_ms) if is_history else None, + "oldest_ms": oldest_ms, + "newest_ms": newest_ms, + "exhausted": exhausted, + "source": "remote" if fetched else "db", + "retention_policy": retention_policy_meta(), + "candles": candles, + "from_cache": from_cache, + "fetched": fetched, + "cleared": cleared, + "purged": purged, + "price_tick": price_tick, + "stale": bool(remote_err), + "stale_message": remote_err if remote_err else None, + "updated_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + } + + +def format_ohlcv_detail(bar: dict[str, Any] | None, tick: Optional[float]) -> dict[str, str]: + if not bar: + return {"open": "-", "high": "-", "low": "-", "close": "-", "volume": "-"} + return { + "open": format_price_by_tick(bar.get("open"), tick), + "high": format_price_by_tick(bar.get("high"), tick), + "low": format_price_by_tick(bar.get("low"), tick), + "close": format_price_by_tick(bar.get("close"), tick), + "volume": format_price_by_tick(bar.get("volume"), tick), + } diff --git a/hub_macro_calendar_lib.py b/lib/hub/hub_macro_calendar_lib.py similarity index 96% rename from hub_macro_calendar_lib.py rename to lib/hub/hub_macro_calendar_lib.py index 45aba3e..10ed915 100644 --- a/hub_macro_calendar_lib.py +++ b/lib/hub/hub_macro_calendar_lib.py @@ -1,311 +1,311 @@ -"""中控宏观关键数据日历:手动录入 FOMC / CPI / 非农档发布时间,±1h 风控前置窗口。""" - -from __future__ import annotations - -import os -import sqlite3 -import time -from datetime import datetime -from pathlib import Path -from typing import Any -from zoneinfo import ZoneInfo - -from hub_symbol_archive_lib import parse_wall_clock_ms - -DISPLAY_TZ = ZoneInfo(os.getenv("APP_TIMEZONE", "Asia/Shanghai")) - -MACRO_EVENT_TYPES = ("fomc", "cpi", "employment") - -MACRO_EVENT_LABELS: dict[str, str] = { - "fomc": "FOMC 联邦基金利率", - "cpi": "美国 CPI 通胀", - "employment": "就业与劳工数据", -} - -WINDOW_BEFORE_MS = int(os.getenv("HUB_MACRO_WINDOW_BEFORE_SEC", str(3600))) * 1000 -WINDOW_AFTER_MS = int(os.getenv("HUB_MACRO_WINDOW_AFTER_SEC", str(3600))) * 1000 -IMMINENT_BEFORE_MS = int(os.getenv("HUB_MACRO_IMMINENT_BEFORE_SEC", str(1800))) * 1000 -LIST_FUTURE_DAYS = int(os.getenv("HUB_MACRO_LIST_FUTURE_DAYS", "60")) - - -def default_db_path() -> Path: - raw = (os.getenv("HUB_MACRO_CALENDAR_DB_PATH") or "").strip() - if raw: - return Path(raw) - hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data" - hub_dir.mkdir(parents=True, exist_ok=True) - return hub_dir / "hub_macro_calendar.db" - - -def _connect(db_path: Path | None = None) -> sqlite3.Connection: - path = db_path or default_db_path() - path.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(path), timeout=30, isolation_level=None) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - return conn - - -def init_db(db_path: Path | None = None) -> None: - conn = _connect(db_path) - try: - conn.execute( - """ - CREATE TABLE IF NOT EXISTS macro_events ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - event_type TEXT NOT NULL, - event_at_ms INTEGER NOT NULL, - note TEXT NOT NULL DEFAULT '', - created_at_ms INTEGER NOT NULL, - updated_at_ms INTEGER NOT NULL - ) - """ - ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_macro_events_at ON macro_events(event_at_ms)" - ) - finally: - conn.close() - - -def normalize_event_type(raw: str) -> str: - key = (raw or "").strip().lower() - if key not in MACRO_EVENT_TYPES: - raise ValueError(f"事件类型须为: {', '.join(MACRO_EVENT_LABELS.values())}") - return key - - -def parse_event_at_ms(raw: Any) -> int: - ms = parse_wall_clock_ms(raw, tz=DISPLAY_TZ) - if ms is None: - raise ValueError("发布时间格式错误,请使用 YYYY-MM-DD HH:MM 或 YYYY-MM-DDTHH:MM") - return int(ms) - - -def format_event_at(ms: int) -> str: - dt = datetime.fromtimestamp(ms / 1000, tz=DISPLAY_TZ) - return dt.strftime("%Y-%m-%d %H:%M") - - -def _row_to_dict(row: sqlite3.Row) -> dict[str, Any]: - ms = int(row["event_at_ms"]) - et = str(row["event_type"]) - return { - "id": int(row["id"]), - "event_type": et, - "event_type_label": MACRO_EVENT_LABELS.get(et, et), - "event_at_ms": ms, - "event_at": format_event_at(ms), - "note": str(row["note"] or ""), - "created_at_ms": int(row["created_at_ms"]), - "updated_at_ms": int(row["updated_at_ms"]), - } - - -def _window_bounds(event_at_ms: int) -> tuple[int, int]: - start = int(event_at_ms) - WINDOW_BEFORE_MS - end = int(event_at_ms) + WINDOW_AFTER_MS - return start, end - - -def enrich_alert(row: dict[str, Any], now_ms: int | None = None) -> dict[str, Any] | None: - now = int(now_ms if now_ms is not None else time.time() * 1000) - event_at_ms = int(row["event_at_ms"]) - window_start, window_end = _window_bounds(event_at_ms) - if now < window_start or now > window_end: - return None - imminent = now >= (event_at_ms - IMMINENT_BEFORE_MS) and now <= window_end - mins_to_event = max(0, int((event_at_ms - now) / 60000)) - mins_from_event = max(0, int((now - event_at_ms) / 60000)) - return { - **row, - "window_start_ms": window_start, - "window_end_ms": window_end, - "window_start": format_event_at(window_start), - "window_end": format_event_at(window_end), - "phase": "imminent" if imminent else "window", - "phase_label": "即将发布" if imminent and now < event_at_ms else "高波动窗口", - "minutes_to_event": mins_to_event if now < event_at_ms else 0, - "minutes_from_event": mins_from_event if now >= event_at_ms else 0, - } - - -def list_events( - *, - now_ms: int | None = None, - include_expired_hours: int = 24, - db_path: Path | None = None, -) -> list[dict[str, Any]]: - init_db(db_path) - now = int(now_ms if now_ms is not None else time.time() * 1000) - horizon = now + LIST_FUTURE_DAYS * 86400 * 1000 - expired_cutoff = now - max(0, int(include_expired_hours)) * 3600 * 1000 - WINDOW_AFTER_MS - conn = _connect(db_path) - try: - rows = conn.execute( - """ - SELECT * FROM macro_events - WHERE event_at_ms >= ? AND event_at_ms <= ? - ORDER BY event_at_ms ASC, id ASC - """, - (expired_cutoff, horizon), - ).fetchall() - return [_row_to_dict(r) for r in rows] - finally: - conn.close() - - -def get_event(event_id: int, db_path: Path | None = None) -> dict[str, Any] | None: - init_db(db_path) - conn = _connect(db_path) - try: - row = conn.execute("SELECT * FROM macro_events WHERE id=?", (int(event_id),)).fetchone() - return _row_to_dict(row) if row else None - finally: - conn.close() - - -def _assert_no_duplicate( - conn: sqlite3.Connection, - event_type: str, - event_at_ms: int, - *, - exclude_id: int | None = None, -) -> None: - if exclude_id is None: - row = conn.execute( - "SELECT id FROM macro_events WHERE event_type=? AND event_at_ms=? LIMIT 1", - (event_type, int(event_at_ms)), - ).fetchone() - else: - row = conn.execute( - """ - SELECT id FROM macro_events - WHERE event_type=? AND event_at_ms=? AND id<>? - LIMIT 1 - """, - (event_type, int(event_at_ms), int(exclude_id)), - ).fetchone() - if row: - raise ValueError("同类型、同发布时间的记录已存在") - - -def create_event( - event_type: str, - event_at: Any, - *, - note: str = "", - db_path: Path | None = None, -) -> dict[str, Any]: - init_db(db_path) - et = normalize_event_type(event_type) - event_at_ms = parse_event_at_ms(event_at) - note_s = str(note or "").strip()[:500] - now_ms = int(time.time() * 1000) - conn = _connect(db_path) - try: - _assert_no_duplicate(conn, et, event_at_ms) - cur = conn.execute( - """ - INSERT INTO macro_events (event_type, event_at_ms, note, created_at_ms, updated_at_ms) - VALUES (?, ?, ?, ?, ?) - """, - (et, event_at_ms, note_s, now_ms, now_ms), - ) - eid = int(cur.lastrowid) - finally: - conn.close() - row = get_event(eid, db_path=db_path) - assert row is not None - return row - - -def update_event( - event_id: int, - *, - event_type: str | None = None, - event_at: Any | None = None, - note: str | None = None, - db_path: Path | None = None, -) -> dict[str, Any] | None: - init_db(db_path) - existing = get_event(event_id, db_path=db_path) - if not existing: - return None - et = normalize_event_type(event_type if event_type is not None else existing["event_type"]) - event_at_ms = ( - parse_event_at_ms(event_at) if event_at is not None else int(existing["event_at_ms"]) - ) - note_s = existing["note"] if note is None else str(note or "").strip()[:500] - now_ms = int(time.time() * 1000) - conn = _connect(db_path) - try: - _assert_no_duplicate(conn, et, event_at_ms, exclude_id=int(event_id)) - conn.execute( - """ - UPDATE macro_events - SET event_type=?, event_at_ms=?, note=?, updated_at_ms=? - WHERE id=? - """, - (et, event_at_ms, note_s, now_ms, int(event_id)), - ) - finally: - conn.close() - return get_event(event_id, db_path=db_path) - - -def delete_event(event_id: int, db_path: Path | None = None) -> bool: - init_db(db_path) - conn = _connect(db_path) - try: - cur = conn.execute("DELETE FROM macro_events WHERE id=?", (int(event_id),)) - return cur.rowcount > 0 - finally: - conn.close() - - -def list_active_alerts( - now_ms: int | None = None, - db_path: Path | None = None, -) -> list[dict[str, Any]]: - now = int(now_ms if now_ms is not None else time.time() * 1000) - lookback = now - WINDOW_BEFORE_MS - IMMINENT_BEFORE_MS - lookahead = now + WINDOW_AFTER_MS - init_db(db_path) - conn = _connect(db_path) - try: - rows = conn.execute( - """ - SELECT * FROM macro_events - WHERE event_at_ms >= ? AND event_at_ms <= ? - ORDER BY event_at_ms ASC, id ASC - """, - (lookback, lookahead), - ).fetchall() - finally: - conn.close() - alerts: list[dict[str, Any]] = [] - for row in rows: - item = enrich_alert(_row_to_dict(row), now_ms=now) - if item: - alerts.append(item) - return alerts - - -def build_banner_message(alert: dict[str, Any], *, has_positions: bool) -> str: - label = alert.get("event_type_label") or alert.get("event_type") or "宏观数据" - phase = alert.get("phase") or "window" - if has_positions: - if phase == "imminent" and int(alert.get("minutes_to_event") or 0) > 0: - return ( - f"「{label}」即将发布(约 {alert['minutes_to_event']} 分钟)," - "注意仓位风险:勿加仓,检查止损/减仓" - ) - return f"「{label}」高波动窗口(±1h),注意仓位风险:勿加仓,检查止损/减仓" - if phase == "imminent" and int(alert.get("minutes_to_event") or 0) > 0: - return ( - f"「{label}」即将发布(约 {alert['minutes_to_event']} 分钟)," - "建议等待,避免新开仓" - ) - return f"「{label}」高波动窗口(±1h),建议等待,避免新开仓" +"""中控宏观关键数据日历:手动录入 FOMC / CPI / 非农档发布时间,±1h 风控前置窗口。""" + +from __future__ import annotations + +import os +import sqlite3 +import time +from datetime import datetime +from pathlib import Path +from typing import Any +from zoneinfo import ZoneInfo + +from lib.hub.hub_symbol_archive_lib import parse_wall_clock_ms + +DISPLAY_TZ = ZoneInfo(os.getenv("APP_TIMEZONE", "Asia/Shanghai")) + +MACRO_EVENT_TYPES = ("fomc", "cpi", "employment") + +MACRO_EVENT_LABELS: dict[str, str] = { + "fomc": "FOMC 联邦基金利率", + "cpi": "美国 CPI 通胀", + "employment": "就业与劳工数据", +} + +WINDOW_BEFORE_MS = int(os.getenv("HUB_MACRO_WINDOW_BEFORE_SEC", str(3600))) * 1000 +WINDOW_AFTER_MS = int(os.getenv("HUB_MACRO_WINDOW_AFTER_SEC", str(3600))) * 1000 +IMMINENT_BEFORE_MS = int(os.getenv("HUB_MACRO_IMMINENT_BEFORE_SEC", str(1800))) * 1000 +LIST_FUTURE_DAYS = int(os.getenv("HUB_MACRO_LIST_FUTURE_DAYS", "60")) + + +def default_db_path() -> Path: + raw = (os.getenv("HUB_MACRO_CALENDAR_DB_PATH") or "").strip() + if raw: + return Path(raw) + hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data" + hub_dir.mkdir(parents=True, exist_ok=True) + return hub_dir / "hub_macro_calendar.db" + + +def _connect(db_path: Path | None = None) -> sqlite3.Connection: + path = db_path or default_db_path() + path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(path), timeout=30, isolation_level=None) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + return conn + + +def init_db(db_path: Path | None = None) -> None: + conn = _connect(db_path) + try: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS macro_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + event_type TEXT NOT NULL, + event_at_ms INTEGER NOT NULL, + note TEXT NOT NULL DEFAULT '', + created_at_ms INTEGER NOT NULL, + updated_at_ms INTEGER NOT NULL + ) + """ + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_macro_events_at ON macro_events(event_at_ms)" + ) + finally: + conn.close() + + +def normalize_event_type(raw: str) -> str: + key = (raw or "").strip().lower() + if key not in MACRO_EVENT_TYPES: + raise ValueError(f"事件类型须为: {', '.join(MACRO_EVENT_LABELS.values())}") + return key + + +def parse_event_at_ms(raw: Any) -> int: + ms = parse_wall_clock_ms(raw, tz=DISPLAY_TZ) + if ms is None: + raise ValueError("发布时间格式错误,请使用 YYYY-MM-DD HH:MM 或 YYYY-MM-DDTHH:MM") + return int(ms) + + +def format_event_at(ms: int) -> str: + dt = datetime.fromtimestamp(ms / 1000, tz=DISPLAY_TZ) + return dt.strftime("%Y-%m-%d %H:%M") + + +def _row_to_dict(row: sqlite3.Row) -> dict[str, Any]: + ms = int(row["event_at_ms"]) + et = str(row["event_type"]) + return { + "id": int(row["id"]), + "event_type": et, + "event_type_label": MACRO_EVENT_LABELS.get(et, et), + "event_at_ms": ms, + "event_at": format_event_at(ms), + "note": str(row["note"] or ""), + "created_at_ms": int(row["created_at_ms"]), + "updated_at_ms": int(row["updated_at_ms"]), + } + + +def _window_bounds(event_at_ms: int) -> tuple[int, int]: + start = int(event_at_ms) - WINDOW_BEFORE_MS + end = int(event_at_ms) + WINDOW_AFTER_MS + return start, end + + +def enrich_alert(row: dict[str, Any], now_ms: int | None = None) -> dict[str, Any] | None: + now = int(now_ms if now_ms is not None else time.time() * 1000) + event_at_ms = int(row["event_at_ms"]) + window_start, window_end = _window_bounds(event_at_ms) + if now < window_start or now > window_end: + return None + imminent = now >= (event_at_ms - IMMINENT_BEFORE_MS) and now <= window_end + mins_to_event = max(0, int((event_at_ms - now) / 60000)) + mins_from_event = max(0, int((now - event_at_ms) / 60000)) + return { + **row, + "window_start_ms": window_start, + "window_end_ms": window_end, + "window_start": format_event_at(window_start), + "window_end": format_event_at(window_end), + "phase": "imminent" if imminent else "window", + "phase_label": "即将发布" if imminent and now < event_at_ms else "高波动窗口", + "minutes_to_event": mins_to_event if now < event_at_ms else 0, + "minutes_from_event": mins_from_event if now >= event_at_ms else 0, + } + + +def list_events( + *, + now_ms: int | None = None, + include_expired_hours: int = 24, + db_path: Path | None = None, +) -> list[dict[str, Any]]: + init_db(db_path) + now = int(now_ms if now_ms is not None else time.time() * 1000) + horizon = now + LIST_FUTURE_DAYS * 86400 * 1000 + expired_cutoff = now - max(0, int(include_expired_hours)) * 3600 * 1000 - WINDOW_AFTER_MS + conn = _connect(db_path) + try: + rows = conn.execute( + """ + SELECT * FROM macro_events + WHERE event_at_ms >= ? AND event_at_ms <= ? + ORDER BY event_at_ms ASC, id ASC + """, + (expired_cutoff, horizon), + ).fetchall() + return [_row_to_dict(r) for r in rows] + finally: + conn.close() + + +def get_event(event_id: int, db_path: Path | None = None) -> dict[str, Any] | None: + init_db(db_path) + conn = _connect(db_path) + try: + row = conn.execute("SELECT * FROM macro_events WHERE id=?", (int(event_id),)).fetchone() + return _row_to_dict(row) if row else None + finally: + conn.close() + + +def _assert_no_duplicate( + conn: sqlite3.Connection, + event_type: str, + event_at_ms: int, + *, + exclude_id: int | None = None, +) -> None: + if exclude_id is None: + row = conn.execute( + "SELECT id FROM macro_events WHERE event_type=? AND event_at_ms=? LIMIT 1", + (event_type, int(event_at_ms)), + ).fetchone() + else: + row = conn.execute( + """ + SELECT id FROM macro_events + WHERE event_type=? AND event_at_ms=? AND id<>? + LIMIT 1 + """, + (event_type, int(event_at_ms), int(exclude_id)), + ).fetchone() + if row: + raise ValueError("同类型、同发布时间的记录已存在") + + +def create_event( + event_type: str, + event_at: Any, + *, + note: str = "", + db_path: Path | None = None, +) -> dict[str, Any]: + init_db(db_path) + et = normalize_event_type(event_type) + event_at_ms = parse_event_at_ms(event_at) + note_s = str(note or "").strip()[:500] + now_ms = int(time.time() * 1000) + conn = _connect(db_path) + try: + _assert_no_duplicate(conn, et, event_at_ms) + cur = conn.execute( + """ + INSERT INTO macro_events (event_type, event_at_ms, note, created_at_ms, updated_at_ms) + VALUES (?, ?, ?, ?, ?) + """, + (et, event_at_ms, note_s, now_ms, now_ms), + ) + eid = int(cur.lastrowid) + finally: + conn.close() + row = get_event(eid, db_path=db_path) + assert row is not None + return row + + +def update_event( + event_id: int, + *, + event_type: str | None = None, + event_at: Any | None = None, + note: str | None = None, + db_path: Path | None = None, +) -> dict[str, Any] | None: + init_db(db_path) + existing = get_event(event_id, db_path=db_path) + if not existing: + return None + et = normalize_event_type(event_type if event_type is not None else existing["event_type"]) + event_at_ms = ( + parse_event_at_ms(event_at) if event_at is not None else int(existing["event_at_ms"]) + ) + note_s = existing["note"] if note is None else str(note or "").strip()[:500] + now_ms = int(time.time() * 1000) + conn = _connect(db_path) + try: + _assert_no_duplicate(conn, et, event_at_ms, exclude_id=int(event_id)) + conn.execute( + """ + UPDATE macro_events + SET event_type=?, event_at_ms=?, note=?, updated_at_ms=? + WHERE id=? + """, + (et, event_at_ms, note_s, now_ms, int(event_id)), + ) + finally: + conn.close() + return get_event(event_id, db_path=db_path) + + +def delete_event(event_id: int, db_path: Path | None = None) -> bool: + init_db(db_path) + conn = _connect(db_path) + try: + cur = conn.execute("DELETE FROM macro_events WHERE id=?", (int(event_id),)) + return cur.rowcount > 0 + finally: + conn.close() + + +def list_active_alerts( + now_ms: int | None = None, + db_path: Path | None = None, +) -> list[dict[str, Any]]: + now = int(now_ms if now_ms is not None else time.time() * 1000) + lookback = now - WINDOW_BEFORE_MS - IMMINENT_BEFORE_MS + lookahead = now + WINDOW_AFTER_MS + init_db(db_path) + conn = _connect(db_path) + try: + rows = conn.execute( + """ + SELECT * FROM macro_events + WHERE event_at_ms >= ? AND event_at_ms <= ? + ORDER BY event_at_ms ASC, id ASC + """, + (lookback, lookahead), + ).fetchall() + finally: + conn.close() + alerts: list[dict[str, Any]] = [] + for row in rows: + item = enrich_alert(_row_to_dict(row), now_ms=now) + if item: + alerts.append(item) + return alerts + + +def build_banner_message(alert: dict[str, Any], *, has_positions: bool) -> str: + label = alert.get("event_type_label") or alert.get("event_type") or "宏观数据" + phase = alert.get("phase") or "window" + if has_positions: + if phase == "imminent" and int(alert.get("minutes_to_event") or 0) > 0: + return ( + f"「{label}」即将发布(约 {alert['minutes_to_event']} 分钟)," + "注意仓位风险:勿加仓,检查止损/减仓" + ) + return f"「{label}」高波动窗口(±1h),注意仓位风险:勿加仓,检查止损/减仓" + if phase == "imminent" and int(alert.get("minutes_to_event") or 0) > 0: + return ( + f"「{label}」即将发布(约 {alert['minutes_to_event']} 分钟)," + "建议等待,避免新开仓" + ) + return f"「{label}」高波动窗口(±1h),建议等待,避免新开仓" diff --git a/hub_market_info_lib.py b/lib/hub/hub_market_info_lib.py similarity index 92% rename from hub_market_info_lib.py rename to lib/hub/hub_market_info_lib.py index 6b76d7c..4c48256 100644 --- a/hub_market_info_lib.py +++ b/lib/hub/hub_market_info_lib.py @@ -1,81 +1,81 @@ -"""实例 USDT 永续合约信息(与实盘 ccxt 精度一致)。""" - -from __future__ import annotations - -from typing import Any, Callable, Optional, Tuple - -from hub_calculator_market_lib import ( - amount_decimals_from_exchange, - normalize_base_symbol, - price_decimals_from_exchange, - resolve_usdt_perp_symbol, -) -from hub_ohlcv_lib import normalize_price_tick, price_tick_from_market - - -def fetch_usdt_swap_market_info( - *, - base_or_symbol: str, - normalize_symbol_input: Callable[[str], str], - normalize_exchange_symbol: Callable[[str], str], - ensure_markets_loaded: Callable[[], None], - exchange: Any, - exchange_id: str = "", -) -> dict[str, Any]: - """供各实例 /api/hub/market 调用。""" - raw = str(base_or_symbol or "").strip() - if not raw: - return {"ok": False, "msg": "请输入币种,如 ETH"} - - try: - ensure_markets_loaded() - except Exception as exc: - return {"ok": False, "msg": f"加载市场失败: {exc}"} - - base_u = normalize_base_symbol(raw) - hub_sym = normalize_symbol_input(raw if base_u else raw) - try: - ex_sym = normalize_exchange_symbol(hub_sym) - except Exception: - ex_sym = hub_sym - - sym, err = resolve_usdt_perp_symbol(exchange, base_u or hub_sym) - if err and ex_sym: - markets = getattr(exchange, "markets", None) or {} - if ex_sym in markets: - sym = ex_sym - err = None - if err or not sym: - return {"ok": False, "msg": err or f"未找到 {base_u or raw}/USDT 永续合约"} - - market = exchange.market(sym) - try: - contract_size = float(market.get("contractSize") or 1.0) - except (TypeError, ValueError): - contract_size = 1.0 - if contract_size <= 0: - contract_size = 1.0 - - price_tick = normalize_price_tick(price_tick_from_market(exchange, sym)) - amt_dec = amount_decimals_from_exchange(exchange, sym) - px_dec = price_decimals_from_exchange(exchange, sym, price_tick) - min_amount = None - try: - min_amount = float((market.get("limits") or {}).get("amount", {}).get("min")) - except (TypeError, ValueError): - min_amount = None - - base_out = (market.get("base") or base_u or "").upper() or base_u - return { - "ok": True, - "exchange": (exchange_id or "").strip().lower(), - "base": base_out, - "exchange_symbol": sym, - "display_symbol": f"{base_out}/USDT" if base_out else sym, - "contract_size": contract_size, - "price_tick": price_tick, - "price_decimals": px_dec, - "amount_decimals": amt_dec, - "min_amount": min_amount, - } - +"""实例 USDT 永续合约信息(与实盘 ccxt 精度一致)。""" + +from __future__ import annotations + +from typing import Any, Callable, Optional, Tuple + +from lib.hub.hub_calculator_market_lib import ( + amount_decimals_from_exchange, + normalize_base_symbol, + price_decimals_from_exchange, + resolve_usdt_perp_symbol, +) +from lib.hub.hub_ohlcv_lib import normalize_price_tick, price_tick_from_market + + +def fetch_usdt_swap_market_info( + *, + base_or_symbol: str, + normalize_symbol_input: Callable[[str], str], + normalize_exchange_symbol: Callable[[str], str], + ensure_markets_loaded: Callable[[], None], + exchange: Any, + exchange_id: str = "", +) -> dict[str, Any]: + """供各实例 /api/hub/market 调用。""" + raw = str(base_or_symbol or "").strip() + if not raw: + return {"ok": False, "msg": "请输入币种,如 ETH"} + + try: + ensure_markets_loaded() + except Exception as exc: + return {"ok": False, "msg": f"加载市场失败: {exc}"} + + base_u = normalize_base_symbol(raw) + hub_sym = normalize_symbol_input(raw if base_u else raw) + try: + ex_sym = normalize_exchange_symbol(hub_sym) + except Exception: + ex_sym = hub_sym + + sym, err = resolve_usdt_perp_symbol(exchange, base_u or hub_sym) + if err and ex_sym: + markets = getattr(exchange, "markets", None) or {} + if ex_sym in markets: + sym = ex_sym + err = None + if err or not sym: + return {"ok": False, "msg": err or f"未找到 {base_u or raw}/USDT 永续合约"} + + market = exchange.market(sym) + try: + contract_size = float(market.get("contractSize") or 1.0) + except (TypeError, ValueError): + contract_size = 1.0 + if contract_size <= 0: + contract_size = 1.0 + + price_tick = normalize_price_tick(price_tick_from_market(exchange, sym)) + amt_dec = amount_decimals_from_exchange(exchange, sym) + px_dec = price_decimals_from_exchange(exchange, sym, price_tick) + min_amount = None + try: + min_amount = float((market.get("limits") or {}).get("amount", {}).get("min")) + except (TypeError, ValueError): + min_amount = None + + base_out = (market.get("base") or base_u or "").upper() or base_u + return { + "ok": True, + "exchange": (exchange_id or "").strip().lower(), + "base": base_out, + "exchange_symbol": sym, + "display_symbol": f"{base_out}/USDT" if base_out else sym, + "contract_size": contract_size, + "price_tick": price_tick, + "price_decimals": px_dec, + "amount_decimals": amt_dec, + "min_amount": min_amount, + } + diff --git a/hub_ohlcv_lib.py b/lib/hub/hub_ohlcv_lib.py similarity index 100% rename from hub_ohlcv_lib.py rename to lib/hub/hub_ohlcv_lib.py diff --git a/hub_position_metrics.py b/lib/hub/hub_position_metrics.py similarity index 100% rename from hub_position_metrics.py rename to lib/hub/hub_position_metrics.py diff --git a/hub_sso.py b/lib/hub/hub_sso.py similarity index 100% rename from hub_sso.py rename to lib/hub/hub_sso.py diff --git a/hub_symbol_archive_lib.py b/lib/hub/hub_symbol_archive_lib.py similarity index 97% rename from hub_symbol_archive_lib.py rename to lib/hub/hub_symbol_archive_lib.py index f9e7426..d911550 100644 --- a/hub_symbol_archive_lib.py +++ b/lib/hub/hub_symbol_archive_lib.py @@ -1,1676 +1,1676 @@ -"""中控币种档案:永久 5m K 线库(建档种子 + 4h 增量),交易缓存与 overlay。""" - -from __future__ import annotations - -import json -import os -import sqlite3 -import time -from datetime import datetime, timedelta, timezone -from pathlib import Path -from typing import Any, Callable, Optional -from zoneinfo import ZoneInfo - -CHART_DISPLAY_TZ = ZoneInfo(os.getenv("APP_TIMEZONE", "Asia/Shanghai")) - -from hub_ohlcv_lib import ( - TIMEFRAME_MS, - aggregate_ohlcv_bars, - normalize_chart_timeframe, - normalize_perpetual_symbol, -) -from hub_trades_lib import ( - display_entry_type_label, - effective_hold_minutes, - format_hold_minutes, -) - -ARCHIVE_TIMEFRAMES = frozenset({"5m", "15m", "1h", "4h"}) -ARCHIVE_DEFAULT_TIMEFRAME = "15m" -ARCHIVE_SEED_LOOKBACK_DAYS = 30 -ARCHIVE_VISIBLE_BARS_DEFAULT = 200 -ARCHIVE_MAX_CANDLES: dict[str, int] = { - "5m": 9000, - "15m": 15000, - "1h": 4000, - "4h": 2000, -} -ARCHIVE_SYNC_INTERVAL_SEC = int(os.getenv("HUB_ARCHIVE_SYNC_INTERVAL_SEC", str(4 * 3600))) -ARCHIVE_TRADE_DAYS = int(os.getenv("HUB_ARCHIVE_TRADE_DAYS", "365")) -ARCHIVE_TRADE_LIMIT = int(os.getenv("HUB_ARCHIVE_TRADE_LIMIT", "2000")) -ARCHIVE_QUOTES_MAX = int(os.getenv("HUB_ARCHIVE_QUOTES_MAX", "100")) -TRADING_DAY_RESET_HOUR = int(os.getenv("TRADING_DAY_RESET_HOUR", "8")) -ARCHIVE_QUOTE_MAX_LEN = 5000 - -BEHAVIOR_TAGS = frozenset({"", "sick", "emotion"}) - - -def default_db_path() -> Path: - raw = (os.getenv("HUB_ARCHIVE_DB_PATH") or "").strip() - if raw: - return Path(raw) - hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data" - hub_dir.mkdir(parents=True, exist_ok=True) - return hub_dir / "hub_symbol_archive.db" - - -def _connect(db_path: Path | None = None) -> sqlite3.Connection: - path = db_path or default_db_path() - path.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(path), timeout=30, isolation_level=None) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - return conn - - -def init_db(db_path: Path | None = None) -> None: - conn = _connect(db_path) - try: - conn.execute( - """ - CREATE TABLE IF NOT EXISTS archive_meta ( - exchange_key TEXT NOT NULL, - symbol TEXT NOT NULL, - first_trade_opened_ms INTEGER, - archive_started_at INTEGER NOT NULL, - last_kline_sync_ms INTEGER, - last_trade_sync_ms INTEGER, - seed_complete INTEGER NOT NULL DEFAULT 0, - PRIMARY KEY (exchange_key, symbol) - ) - """ - ) - conn.execute( - """ - CREATE TABLE IF NOT EXISTS archive_bars_5m ( - exchange_key TEXT NOT NULL, - symbol TEXT NOT NULL, - open_time_ms INTEGER NOT NULL, - open REAL NOT NULL, - high REAL NOT NULL, - low REAL NOT NULL, - close REAL NOT NULL, - volume REAL NOT NULL DEFAULT 0, - updated_at INTEGER NOT NULL, - PRIMARY KEY (exchange_key, symbol, open_time_ms) - ) - """ - ) - conn.execute( - """ - CREATE INDEX IF NOT EXISTS idx_archive_bars_series - ON archive_bars_5m (exchange_key, symbol, open_time_ms) - """ - ) - conn.execute( - """ - CREATE TABLE IF NOT EXISTS archive_trade_cache ( - exchange_key TEXT NOT NULL, - trade_id INTEGER NOT NULL, - symbol TEXT NOT NULL, - direction TEXT, - result TEXT, - pnl_amount REAL, - opened_at TEXT, - closed_at TEXT, - opened_at_ms INTEGER, - closed_at_ms INTEGER, - monitor_type TEXT, - entry_reason TEXT, - exchange_turnover_usdt REAL, - exchange_commission_usdt REAL, - payload_json TEXT, - synced_at INTEGER NOT NULL, - PRIMARY KEY (exchange_key, trade_id) - ) - """ - ) - conn.execute( - """ - CREATE INDEX IF NOT EXISTS idx_archive_trades_sym - ON archive_trade_cache (exchange_key, symbol, closed_at_ms) - """ - ) - conn.execute( - """ - CREATE TABLE IF NOT EXISTS trade_overlay ( - exchange_key TEXT NOT NULL, - trade_id INTEGER NOT NULL, - behavior_tag TEXT NOT NULL DEFAULT '', - note TEXT NOT NULL DEFAULT '', - updated_at INTEGER NOT NULL, - PRIMARY KEY (exchange_key, trade_id) - ) - """ - ) - conn.execute( - """ - CREATE TABLE IF NOT EXISTS archive_review_quotes ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - quote_date TEXT NOT NULL UNIQUE, - content TEXT NOT NULL DEFAULT '', - created_at INTEGER NOT NULL, - updated_at INTEGER NOT NULL - ) - """ - ) - conn.execute( - """ - CREATE INDEX IF NOT EXISTS idx_archive_quotes_date - ON archive_review_quotes (quote_date DESC) - """ - ) - for ddl in ( - "ALTER TABLE archive_trade_cache ADD COLUMN exchange_turnover_usdt REAL", - "ALTER TABLE archive_trade_cache ADD COLUMN exchange_commission_usdt REAL", - ): - try: - conn.execute(ddl) - except Exception: - pass - finally: - conn.close() - - -def _now_ms() -> int: - return int(time.time() * 1000) - - -def _optional_float(raw: Any) -> float | None: - if raw in (None, ""): - return None - try: - return float(raw) - except (TypeError, ValueError): - return None - - -def parse_wall_clock_ms(raw: Any, *, tz: ZoneInfo = CHART_DISPLAY_TZ) -> int | None: - """将 YYYY-MM-DD[ HH:MM[:SS]] 按指定时区墙钟解析为 UTC 毫秒(默认 UTC+8)。""" - if raw in (None, ""): - return None - try: - if isinstance(raw, (int, float)): - v = int(raw) - return v if v > 1_000_000_000_000 else v * 1000 - except (TypeError, ValueError): - pass - s = str(raw).strip().replace("Z", "").replace("T", " ") - if not s: - return None - if s.isdigit(): - v = int(s) - return v if v > 1_000_000_000_000 else v * 1000 - for fmt, ln in (("%Y-%m-%d %H:%M:%S", 19), ("%Y-%m-%d %H:%M", 16), ("%Y-%m-%d", 10)): - try: - dt = datetime.strptime(s[:ln], fmt) - aware = dt.replace(tzinfo=tz) - return int(aware.timestamp() * 1000) - except ValueError: - continue - return None - - -def ms_to_wall_clock_str(ms: int, *, tz: ZoneInfo = CHART_DISPLAY_TZ) -> str: - dt = datetime.fromtimestamp(int(ms) / 1000.0, tz=timezone.utc).astimezone(tz) - return dt.strftime("%Y-%m-%d %H:%M:%S") - - -def _parse_dt_ms(raw: Any) -> int | None: - return parse_wall_clock_ms(raw) - - -def _trade_entry_reason_for_cache(t: dict[str, Any]) -> str: - for key in ("entry_type", "entry_reason", "reviewed_entry_reason"): - raw = t.get(key) - if raw is not None and str(raw).strip(): - return str(raw).strip() - return display_entry_type_label(t) if isinstance(t, dict) else "" - - -def purge_stale_trades_cache( - exchange_key: str, - active_trade_ids: list[int] | set[int], - *, - db_path: Path | None = None, -) -> int: - """删除该所缓存中已不在复盘/交易记录里的条目。""" - ex_k = (exchange_key or "").strip().lower() - if not ex_k: - return 0 - ids: list[int] = [] - for raw in active_trade_ids or []: - try: - ids.append(int(raw)) - except (TypeError, ValueError): - continue - conn = _connect(db_path) - try: - if not ids: - rows = conn.execute( - "SELECT trade_id FROM archive_trade_cache WHERE exchange_key=?", - (ex_k,), - ).fetchall() - stale_ids = [int(r["trade_id"]) for r in rows] - cur = conn.execute( - "DELETE FROM archive_trade_cache WHERE exchange_key=?", - (ex_k,), - ) - else: - placeholders = ",".join("?" * len(ids)) - rows = conn.execute( - f""" - SELECT trade_id FROM archive_trade_cache - WHERE exchange_key=? AND trade_id NOT IN ({placeholders}) - """, - (ex_k, *ids), - ).fetchall() - stale_ids = [int(r["trade_id"]) for r in rows] - cur = conn.execute( - f""" - DELETE FROM archive_trade_cache - WHERE exchange_key=? AND trade_id NOT IN ({placeholders}) - """, - (ex_k, *ids), - ) - removed = int(cur.rowcount or 0) - if stale_ids: - ph2 = ",".join("?" * len(stale_ids)) - conn.execute( - f""" - DELETE FROM trade_overlay - WHERE exchange_key=? AND trade_id IN ({ph2}) - """, - (ex_k, *stale_ids), - ) - return removed - finally: - conn.close() - - -def delete_trade_from_archive( - exchange_key: str, - trade_id: int, - *, - db_path: Path | None = None, -) -> bool: - ex_k = (exchange_key or "").strip().lower() - tid = int(trade_id) - conn = _connect(db_path) - try: - cur = conn.execute( - """ - DELETE FROM archive_trade_cache - WHERE exchange_key=? AND trade_id=? - """, - (ex_k, tid), - ) - conn.execute( - "DELETE FROM trade_overlay WHERE exchange_key=? AND trade_id=?", - (ex_k, tid), - ) - return int(cur.rowcount or 0) > 0 - finally: - conn.close() - - -def upsert_trades_cache( - exchange_key: str, - trades: list[dict[str, Any]], - *, - db_path: Path | None = None, - prune_missing: bool = True, -) -> dict[str, int]: - init_db(db_path) - ex_k = (exchange_key or "").strip().lower() - if not ex_k: - return {"upserted": 0, "removed": 0} - now = _now_ms() - n = 0 - active_ids: list[int] = [] - conn = _connect(db_path) - try: - for t in trades or []: - try: - tid = int(t.get("id")) - except (TypeError, ValueError): - continue - sym = (t.get("symbol") or "").strip().upper() - if not sym: - continue - active_ids.append(tid) - row = dict(t) - row["exchange_key"] = ex_k - row.pop("account_exchange_key", None) - payload = {k: row.get(k) for k in row.keys()} - entry_label = _trade_entry_reason_for_cache(t) - conn.execute( - """ - INSERT INTO archive_trade_cache ( - exchange_key, trade_id, symbol, direction, result, pnl_amount, - opened_at, closed_at, opened_at_ms, closed_at_ms, - monitor_type, entry_reason, exchange_turnover_usdt, exchange_commission_usdt, - payload_json, synced_at - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) - ON CONFLICT(exchange_key, trade_id) DO UPDATE SET - symbol=excluded.symbol, - direction=excluded.direction, - result=excluded.result, - pnl_amount=excluded.pnl_amount, - opened_at=excluded.opened_at, - closed_at=excluded.closed_at, - opened_at_ms=excluded.opened_at_ms, - closed_at_ms=excluded.closed_at_ms, - monitor_type=excluded.monitor_type, - entry_reason=excluded.entry_reason, - exchange_turnover_usdt=excluded.exchange_turnover_usdt, - exchange_commission_usdt=excluded.exchange_commission_usdt, - payload_json=excluded.payload_json, - synced_at=excluded.synced_at - """, - ( - ex_k, - tid, - sym, - t.get("direction"), - t.get("result"), - float(t.get("pnl_amount") or 0), - t.get("opened_at"), - t.get("closed_at"), - t.get("opened_at_ms") or _parse_dt_ms(t.get("opened_at")), - t.get("closed_at_ms") or _parse_dt_ms(t.get("closed_at")), - t.get("monitor_type"), - entry_label, - _optional_float(t.get("exchange_turnover_usdt")), - _optional_float(t.get("exchange_commission_usdt")), - json.dumps(payload, ensure_ascii=False, default=str), - now, - ), - ) - n += 1 - finally: - conn.close() - removed = 0 - if prune_missing: - removed = purge_stale_trades_cache(ex_k, active_ids, db_path=db_path) - return {"upserted": n, "removed": removed} - - -def _enrich_trade_display_fields(out: dict[str, Any]) -> dict[str, Any]: - """缓存行补齐复盘优先的展示字段(兼容旧同步数据)。""" - opened_ms = out.get("opened_at_ms") or _parse_dt_ms(out.get("opened_at")) - closed_ms = out.get("closed_at_ms") or _parse_dt_ms(out.get("closed_at")) - if opened_ms: - out["opened_at_ms"] = int(opened_ms) - if closed_ms: - out["closed_at_ms"] = int(closed_ms) - if not out.get("opened_at") and opened_ms: - out["opened_at"] = ms_to_wall_clock_str(int(opened_ms)) - if not out.get("closed_at") and closed_ms: - out["closed_at"] = ms_to_wall_clock_str(int(closed_ms)) - entry_type = display_entry_type_label(out) - if entry_type and entry_type != "—": - out["entry_type"] = entry_type - out["entry_reason"] = entry_type - hold_m = out.get("hold_minutes") - if hold_m in (None, ""): - hold_m = effective_hold_minutes( - out, - opened_ms=out.get("opened_at_ms"), - closed_ms=out.get("closed_at_ms"), - ) - try: - hold_m = max(0, int(hold_m or 0)) - except (TypeError, ValueError): - hold_m = 0 - out["hold_minutes"] = hold_m - out["hold_minutes_text"] = out.get("hold_minutes_text") or format_hold_minutes(hold_m) - if "reviewed" not in out: - out["reviewed"] = bool( - out.get("reviewed_at") - or out.get("reviewed_result") - or out.get("reviewed_opened_at") - or out.get("reviewed_closed_at") - or out.get("reviewed_entry_reason") - or out.get("reviewed_hold_minutes") - ) - return out - - -def _trade_row_to_dict(row: sqlite3.Row, overlay: dict | None = None) -> dict[str, Any]: - d = dict(row) - payload = {} - raw = d.pop("payload_json", None) - if raw: - try: - payload = json.loads(raw) - except (json.JSONDecodeError, TypeError): - payload = {} - out = {**payload, **{k: d[k] for k in d.keys() if k not in payload}} - for key in ( - "exchange_key", - "symbol", - "trade_id", - "direction", - "result", - "pnl_amount", - "opened_at", - "closed_at", - "opened_at_ms", - "closed_at_ms", - "monitor_type", - "entry_reason", - "exchange_turnover_usdt", - "exchange_commission_usdt", - "synced_at", - ): - if key in d and d[key] not in (None, ""): - out[key] = d[key] - ov = overlay or {} - out["behavior_tag"] = ov.get("behavior_tag") or "" - out["note"] = ov.get("note") or "" - out["trade_id"] = out.get("trade_id") or out.get("id") - ex_col = str(d.get("exchange_key") or "").strip().lower() - if ex_col: - out["exchange_key"] = ex_col - out.pop("account_exchange_key", None) - return _enrich_trade_display_fields(out) - - -def load_overlays( - exchange_key: str, - trade_ids: list[int] | None = None, - *, - db_path: Path | None = None, -) -> dict[int, dict[str, Any]]: - ex_k = (exchange_key or "").strip().lower() - conn = _connect(db_path) - try: - if trade_ids: - placeholders = ",".join("?" * len(trade_ids)) - rows = conn.execute( - f""" - SELECT exchange_key, trade_id, behavior_tag, note, updated_at - FROM trade_overlay - WHERE exchange_key=? AND trade_id IN ({placeholders}) - """, - (ex_k, *trade_ids), - ).fetchall() - else: - rows = conn.execute( - """ - SELECT exchange_key, trade_id, behavior_tag, note, updated_at - FROM trade_overlay WHERE exchange_key=? - """, - (ex_k,), - ).fetchall() - return { - int(r["trade_id"]): { - "behavior_tag": r["behavior_tag"] or "", - "note": r["note"] or "", - "updated_at": r["updated_at"], - } - for r in rows - } - finally: - conn.close() - - -def upsert_trade_overlay( - exchange_key: str, - trade_id: int, - *, - behavior_tag: str | None = None, - note: str | None = None, - db_path: Path | None = None, -) -> dict[str, Any]: - init_db(db_path) - ex_k = (exchange_key or "").strip().lower() - tid = int(trade_id) - tag = (behavior_tag or "").strip().lower() - if tag not in BEHAVIOR_TAGS: - tag = "" - note_text = (note or "").strip()[:2000] - now = _now_ms() - conn = _connect(db_path) - try: - conn.execute( - """ - INSERT INTO trade_overlay (exchange_key, trade_id, behavior_tag, note, updated_at) - VALUES (?,?,?,?,?) - ON CONFLICT(exchange_key, trade_id) DO UPDATE SET - behavior_tag=excluded.behavior_tag, - note=excluded.note, - updated_at=excluded.updated_at - """, - (ex_k, tid, tag, note_text, now), - ) - finally: - conn.close() - return {"exchange_key": ex_k, "trade_id": tid, "behavior_tag": tag, "note": note_text} - - -def list_symbol_rows( - *, - exchange_key: str = "", - filter_profit: bool = False, - filter_loss: bool = False, - filter_sick: bool = False, - filter_emotion: bool = False, - db_path: Path | None = None, -) -> list[dict[str, Any]]: - """一所一币一行汇总。""" - init_db(db_path) - conn = _connect(db_path) - try: - params: list[Any] = [] - where = "1=1" - ex_filter = (exchange_key or "").strip().lower() - if ex_filter: - where += " AND t.exchange_key=?" - params.append(ex_filter) - - rows = conn.execute( - f""" - SELECT t.exchange_key, t.symbol, - COUNT(*) AS trade_count, - SUM(CASE WHEN t.pnl_amount > 0.0001 THEN 1 ELSE 0 END) AS win_count, - SUM(CASE WHEN t.pnl_amount < -0.0001 THEN 1 ELSE 0 END) AS loss_count, - SUM(COALESCE(t.pnl_amount, 0)) AS total_pnl, - MIN(COALESCE(t.opened_at_ms, 0)) AS first_opened_ms, - MAX(COALESCE(t.closed_at_ms, 0)) AS last_closed_ms - FROM archive_trade_cache t - WHERE {where} - GROUP BY t.exchange_key, t.symbol - ORDER BY last_closed_ms DESC - """, - params, - ).fetchall() - - overlays_by_ex: dict[str, dict[int, dict]] = {} - out: list[dict[str, Any]] = [] - for r in rows: - ex_k = r["exchange_key"] - sym = r["symbol"] - if ex_k not in overlays_by_ex: - overlays_by_ex[ex_k] = load_overlays(ex_k, db_path=db_path) - - trade_rows = conn.execute( - """ - SELECT trade_id, pnl_amount FROM archive_trade_cache - WHERE exchange_key=? AND symbol=? - """, - (ex_k, sym), - ).fetchall() - has_profit = any(float(x["pnl_amount"] or 0) > 0.0001 for x in trade_rows) - has_loss = any(float(x["pnl_amount"] or 0) < -0.0001 for x in trade_rows) - has_sick = False - has_emotion = False - ov_map = overlays_by_ex.get(ex_k) or {} - for tr in trade_rows: - ov = ov_map.get(int(tr["trade_id"])) or {} - if ov.get("behavior_tag") == "sick": - has_sick = True - if ov.get("behavior_tag") == "emotion": - has_emotion = True - - if filter_profit and not has_profit: - continue - if filter_loss and not has_loss: - continue - if filter_sick and not has_sick: - continue - if filter_emotion and not has_emotion: - continue - - meta = conn.execute( - "SELECT seed_complete, last_kline_sync_ms FROM archive_meta WHERE exchange_key=? AND symbol=?", - (ex_k, sym), - ).fetchone() - - out.append( - { - "exchange_key": ex_k, - "symbol": sym, - "trade_count": int(r["trade_count"] or 0), - "win_count": int(r["win_count"] or 0), - "loss_count": int(r["loss_count"] or 0), - "total_pnl": round(float(r["total_pnl"] or 0), 4), - "first_opened_ms": int(r["first_opened_ms"] or 0) or None, - "last_closed_ms": int(r["last_closed_ms"] or 0) or None, - "seed_complete": bool(meta["seed_complete"]) if meta else False, - "last_kline_sync_ms": int(meta["last_kline_sync_ms"] or 0) if meta else None, - } - ) - return out - finally: - conn.close() - - -def load_symbol_trades( - exchange_key: str, - symbol: str, - *, - db_path: Path | None = None, -) -> list[dict[str, Any]]: - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - conn = _connect(db_path) - try: - rows = conn.execute( - """ - SELECT * FROM archive_trade_cache - WHERE exchange_key=? AND symbol=? - ORDER BY COALESCE(closed_at_ms, 0) DESC, trade_id DESC - """, - (ex_k, sym), - ).fetchall() - ids = [int(r["trade_id"]) for r in rows] - ov = load_overlays(ex_k, ids, db_path=db_path) - return [_trade_row_to_dict(r, ov.get(int(r["trade_id"]))) for r in rows] - finally: - conn.close() - - -def upsert_bars_5m( - exchange_key: str, - symbol: str, - bars: list[dict[str, Any]], - *, - db_path: Path | None = None, -) -> int: - init_db(db_path) - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - now = _now_ms() - n = 0 - conn = _connect(db_path) - try: - for b in bars or []: - try: - conn.execute( - """ - INSERT INTO archive_bars_5m ( - exchange_key, symbol, open_time_ms, open, high, low, close, volume, updated_at - ) VALUES (?,?,?,?,?,?,?,?,?) - ON CONFLICT(exchange_key, symbol, open_time_ms) DO UPDATE SET - open=excluded.open, - high=excluded.high, - low=excluded.low, - close=excluded.close, - volume=excluded.volume, - updated_at=excluded.updated_at - """, - ( - ex_k, - sym, - int(b["open_time_ms"]), - float(b["open"]), - float(b["high"]), - float(b["low"]), - float(b["close"]), - float(b.get("volume") or 0), - now, - ), - ) - n += 1 - except (KeyError, TypeError, ValueError): - continue - finally: - conn.close() - return n - - -def load_bars_5m_range( - exchange_key: str, - symbol: str, - start_ms: int, - end_ms: int, - *, - db_path: Path | None = None, -) -> list[dict[str, Any]]: - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - conn = _connect(db_path) - try: - rows = conn.execute( - """ - SELECT open_time_ms, open, high, low, close, volume - FROM archive_bars_5m - WHERE exchange_key=? AND symbol=? - AND open_time_ms >= ? AND open_time_ms <= ? - ORDER BY open_time_ms ASC - """, - (ex_k, sym, int(start_ms), int(end_ms)), - ).fetchall() - return [ - { - "open_time_ms": int(r["open_time_ms"]), - "open": float(r["open"]), - "high": float(r["high"]), - "low": float(r["low"]), - "close": float(r["close"]), - "volume": float(r["volume"] or 0), - } - for r in rows - ] - finally: - conn.close() - - -def _to_candles(bars: list[dict[str, Any]]) -> list[dict[str, Any]]: - out = [] - for b in bars or []: - try: - out.append( - { - "time": int(b["open_time_ms"] // 1000), - "open": float(b["open"]), - "high": float(b["high"]), - "low": float(b["low"]), - "close": float(b["close"]), - "volume": float(b.get("volume") or 0), - } - ) - except (KeyError, TypeError, ValueError): - continue - return out - - -def _snap_to_bar_grid(ts_ms: int, origin_ms: int, step_ms: int) -> int: - step = max(1, int(step_ms)) - origin = int(origin_ms) - if ts_ms <= origin: - return origin - idx = (int(ts_ms) - origin + step - 1) // step - return origin + idx * step - - -def _fill_missing_bars( - bars: list[dict[str, Any]], - period_ms: int, - start_ms: int, - end_ms: int, -) -> list[dict[str, Any]]: - """5m 缺口用上一根收盘价填平,保证聚合后 K 线时间轴连续。""" - by_ts: dict[int, dict[str, Any]] = {} - for b in bars or []: - try: - by_ts[int(b["open_time_ms"])] = b - except (KeyError, TypeError, ValueError): - continue - if not by_ts: - return [] - keys = sorted(by_ts.keys()) - step_ms = max(1, int(period_ms)) - origin = keys[0] - aligned_start = _snap_to_bar_grid(int(start_ms), origin, step_ms) - aligned_end = max(int(end_ms), keys[-1]) - out: list[dict[str, Any]] = [] - last: dict[str, Any] | None = None - for ts_key in keys: - if ts_key <= aligned_start: - last = by_ts[ts_key] - ts = aligned_start - while ts <= aligned_end: - cur = by_ts.get(ts) - if cur is not None: - last = cur - out.append(cur) - elif last is not None: - c = float(last["close"]) - out.append( - { - "open_time_ms": ts, - "open": c, - "high": c, - "low": c, - "close": c, - "volume": 0.0, - "filled": True, - } - ) - ts += step_ms - return out - - -def _archive_earliest_bar_ms( - exchange_key: str, - symbol: str, - *, - db_path: Path | None = None, -) -> int | None: - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - conn = _connect(db_path) - try: - row = conn.execute( - "SELECT MIN(open_time_ms) AS mn FROM archive_bars_5m WHERE exchange_key=? AND symbol=?", - (ex_k, sym), - ).fetchone() - if row and row["mn"] is not None: - return int(row["mn"]) - finally: - conn.close() - return None - - -def _trim_bars_for_cap( - bars: list[dict[str, Any]], - *, - end_ms: int, - max_n: int, -) -> list[dict[str, Any]]: - """超长时优先保留到平仓,再从最古老端截断。""" - if len(bars) <= max_n: - return bars - cut_end = len(bars) - for i in range(len(bars) - 1, -1, -1): - if int(bars[i]["open_time_ms"]) <= int(end_ms): - cut_end = i + 1 - break - essential = bars[:cut_end] - if len(essential) <= max_n: - return essential - return essential[len(essential) - max_n :] - - -def resolve_archive_chart( - exchange_key: str, - symbol: str, - timeframe: str = ARCHIVE_DEFAULT_TIMEFRAME, - *, - anchor_ms: int | None = None, - opened_ms: int | None = None, - closed_ms: int | None = None, - mode: str = "hold", - bars: int = ARCHIVE_VISIBLE_BARS_DEFAULT, - range_mode: str = "window", - db_path: Path | None = None, -) -> dict[str, Any]: - """从永久 5m 库聚合出档案 K 线视窗。 - - range_mode=history:建档起点 → 平仓(不含「到现在」),供拖动/缩放查看建仓前全局形态。 - """ - tf = normalize_chart_timeframe(timeframe, default=ARCHIVE_DEFAULT_TIMEFRAME) - if tf not in ARCHIVE_TIMEFRAMES: - return {"ok": False, "msg": f"档案仅支持 {', '.join(sorted(ARCHIVE_TIMEFRAMES))}"} - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - if not ex_k or not sym: - return {"ok": False, "msg": "缺少 exchange_key 或 symbol"} - - period = TIMEFRAME_MS[tf] - period_5m = TIMEFRAME_MS["5m"] - hold_open = int(opened_ms) if opened_ms else None - hold_close = int(closed_ms) if closed_ms else None - rm = (range_mode or "window").strip().lower() - if hold_open and hold_close and hold_close >= hold_open and rm == "history": - seed_back = max(0, hold_open - ARCHIVE_SEED_LOOKBACK_DAYS * 86400000) - earliest = _archive_earliest_bar_ms(ex_k, sym, db_path=db_path) - if earliest is not None: - start_ms = min(earliest, seed_back) - else: - start_ms = seed_back - end_ms = hold_close + max(period * 16, period_5m * 8) - anchor = hold_close if (mode or "hold").strip().lower() != "entry" else hold_open - elif hold_open and hold_close and hold_close >= hold_open: - hold_len = hold_close - hold_open - pad = max(period * 24, hold_len // 3, period_5m * 12) - start_ms = max(0, hold_open - pad) - end_ms = hold_close + pad - anchor = hold_close if (mode or "hold").strip().lower() != "entry" else hold_open - else: - visible = max(50, min(int(bars or ARCHIVE_VISIBLE_BARS_DEFAULT), 500)) - anchor = int(anchor_ms) if anchor_ms else _now_ms() - half = visible // 2 - start_ms = max(0, anchor - half * period) - end_ms = anchor + half * period - - raw_5m = load_bars_5m_range( - ex_k, - sym, - start_ms - period_5m * 6, - end_ms + period_5m * 6, - db_path=db_path, - ) - if not raw_5m: - return {"ok": False, "msg": "档案库暂无 K 线,请等待同步或手动刷新"} - - filled_5m = _fill_missing_bars(raw_5m, period_5m, start_ms - period_5m * 2, end_ms + period_5m * 2) - - if tf == "5m": - merged = [b for b in filled_5m if start_ms <= int(b["open_time_ms"]) <= end_ms] - else: - agg = aggregate_ohlcv_bars(filled_5m, tf) - merged = [b for b in agg if start_ms <= int(b["open_time_ms"]) <= end_ms] - - max_n = ARCHIVE_MAX_CANDLES.get(tf, 2000) - if rm == "history" and merged and len(merged) > max_n: - merged = merged[:max_n] - - candles = _to_candles(merged) - if not candles: - return {"ok": False, "msg": "视窗内无 K 线"} - - ex_sym = normalize_perpetual_symbol(sym) - return { - "ok": True, - "exchange_key": ex_k, - "symbol": sym, - "exchange_symbol": ex_sym, - "market_type": "swap", - "timeframe": tf, - "mode": (mode or "hold").strip().lower(), - "range_mode": rm, - "anchor_ms": anchor, - "opened_ms": hold_open, - "closed_ms": hold_close, - "window_start_ms": start_ms, - "window_end_ms": end_ms, - "candles": candles, - "bar_count": len(candles), - "gaps_filled": sum(1 for b in filled_5m if b.get("filled")), - } - - -def _ensure_meta( - exchange_key: str, - symbol: str, - first_opened_ms: int | None, - *, - db_path: Path | None = None, -) -> None: - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - now = _now_ms() - conn = _connect(db_path) - try: - row = conn.execute( - "SELECT first_trade_opened_ms FROM archive_meta WHERE exchange_key=? AND symbol=?", - (ex_k, sym), - ).fetchone() - if row: - if first_opened_ms and ( - not row["first_trade_opened_ms"] - or int(first_opened_ms) < int(row["first_trade_opened_ms"]) - ): - conn.execute( - """ - UPDATE archive_meta SET first_trade_opened_ms=? - WHERE exchange_key=? AND symbol=? - """, - (int(first_opened_ms), ex_k, sym), - ) - return - conn.execute( - """ - INSERT INTO archive_meta ( - exchange_key, symbol, first_trade_opened_ms, - archive_started_at, last_kline_sync_ms, last_trade_sync_ms, seed_complete - ) VALUES (?,?,?,?,?,?,0) - """, - (ex_k, sym, int(first_opened_ms) if first_opened_ms else None, now, None, None), - ) - finally: - conn.close() - - -def _mark_meta_sync( - exchange_key: str, - symbol: str, - *, - kline_ms: int | None = None, - trade_ms: int | None = None, - seed_complete: bool | None = None, - db_path: Path | None = None, -) -> None: - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - conn = _connect(db_path) - try: - sets = [] - params: list[Any] = [] - if kline_ms is not None: - sets.append("last_kline_sync_ms=?") - params.append(int(kline_ms)) - if trade_ms is not None: - sets.append("last_trade_sync_ms=?") - params.append(int(trade_ms)) - if seed_complete is not None: - sets.append("seed_complete=?") - params.append(1 if seed_complete else 0) - if not sets: - return - params.extend([ex_k, sym]) - conn.execute( - f"UPDATE archive_meta SET {', '.join(sets)} WHERE exchange_key=? AND symbol=?", - params, - ) - finally: - conn.close() - - -def fetch_remote_5m_range( - remote_fetch: Callable[..., dict[str, Any]], - symbol: str, - start_ms: int, - end_ms: int, -) -> list[dict[str, Any]]: - """经实例 /api/hub/ohlcv 分页拉取 5m。""" - period = TIMEFRAME_MS["5m"] - since = max(0, int(start_ms)) - end = int(end_ms) - merged: dict[int, dict[str, Any]] = {} - guard = 0 - while since < end and guard < 120: - guard += 1 - remote = remote_fetch(symbol=symbol, timeframe="5m", since_ms=since, limit=500) - if not remote.get("ok"): - break - batch = remote.get("bars") or [] - if not batch: - break - for b in batch: - try: - ts = int(b["open_time_ms"]) - merged[ts] = b - except (KeyError, TypeError, ValueError): - continue - last_ts = max(int(b["open_time_ms"]) for b in batch) - next_since = last_ts + period - if next_since <= since: - break - since = next_since - if last_ts >= end: - break - return [merged[k] for k in sorted(merged.keys()) if start_ms <= k <= end] - - -def seed_symbol_archive( - exchange_key: str, - symbol: str, - first_opened_ms: int, - remote_fetch: Callable[..., dict[str, Any]], - *, - db_path: Path | None = None, -) -> dict[str, Any]: - """建档:最早开仓向前 30 天 5m 种子。""" - init_db(db_path) - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - anchor = int(first_opened_ms) - start_ms = max(0, anchor - ARCHIVE_SEED_LOOKBACK_DAYS * 86400000) - end_ms = _now_ms() - _ensure_meta(ex_k, sym, anchor, db_path=db_path) - bars = fetch_remote_5m_range(remote_fetch, sym, start_ms, end_ms) - n = upsert_bars_5m(ex_k, sym, bars, db_path=db_path) - now = _now_ms() - _mark_meta_sync(ex_k, sym, kline_ms=now, seed_complete=True, db_path=db_path) - return {"ok": True, "seed_bars": n, "start_ms": start_ms, "end_ms": end_ms} - - -def sync_symbol_klines_incremental( - exchange_key: str, - symbol: str, - remote_fetch: Callable[..., dict[str, Any]], - *, - db_path: Path | None = None, -) -> dict[str, Any]: - """增量补 5m 至当前。""" - init_db(db_path) - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - conn = _connect(db_path) - try: - row = conn.execute( - "SELECT MAX(open_time_ms) AS mx FROM archive_bars_5m WHERE exchange_key=? AND symbol=?", - (ex_k, sym), - ).fetchone() - last_bar = int(row["mx"]) if row and row["mx"] else None - finally: - conn.close() - - period = TIMEFRAME_MS["5m"] - start_ms = max(0, (last_bar + period) if last_bar else 0) - end_ms = _now_ms() - if start_ms >= end_ms - period: - return {"ok": True, "appended": 0, "skipped": True} - bars = fetch_remote_5m_range(remote_fetch, sym, start_ms, end_ms) - n = upsert_bars_5m(ex_k, sym, bars, db_path=db_path) - now = _now_ms() - _mark_meta_sync(ex_k, sym, kline_ms=now, db_path=db_path) - return {"ok": True, "appended": n, "start_ms": start_ms, "end_ms": end_ms} - - -def sync_exchange_symbol_archives( - exchange_key: str, - trades: list[dict[str, Any]], - remote_fetch: Callable[..., dict[str, Any]], - *, - db_path: Path | None = None, -) -> dict[str, Any]: - """同步单所:交易缓存 + 各币种 K 线种子/增量。""" - ex_k = (exchange_key or "").strip().lower() - cache_stats = upsert_trades_cache(ex_k, trades, db_path=db_path, prune_missing=True) - - by_sym: dict[str, int] = {} - for t in trades or []: - sym = (t.get("symbol") or "").strip().upper() - if not sym: - continue - oms = t.get("opened_at_ms") or _parse_dt_ms(t.get("opened_at")) - if oms: - cur = by_sym.get(sym) - if cur is None or int(oms) < cur: - by_sym[sym] = int(oms) - - seeded = 0 - appended = 0 - for sym, first_ms in by_sym.items(): - _ensure_meta(ex_k, sym, first_ms, db_path=db_path) - conn = _connect(db_path) - try: - meta = conn.execute( - "SELECT seed_complete FROM archive_meta WHERE exchange_key=? AND symbol=?", - (ex_k, sym), - ).fetchone() - finally: - conn.close() - if not meta or not int(meta["seed_complete"] or 0): - r = seed_symbol_archive(ex_k, sym, first_ms, remote_fetch, db_path=db_path) - seeded += int(r.get("seed_bars") or 0) - else: - r = sync_symbol_klines_incremental(ex_k, sym, remote_fetch, db_path=db_path) - appended += int(r.get("appended") or 0) - - return { - "ok": True, - "exchange_key": ex_k, - "symbols": len(by_sym), - "trades_upserted": int(cache_stats.get("upserted") or 0), - "trades_removed": int(cache_stats.get("removed") or 0), - "seed_bars": seeded, - "appended_bars": appended, - "trades": len(trades or []), - } - - -def ms_to_trading_day( - ms: int | None, - *, - reset_hour: int = TRADING_DAY_RESET_HOUR, - tz: ZoneInfo = CHART_DISPLAY_TZ, -) -> str | None: - if ms is None: - return None - try: - dt = datetime.fromtimestamp(int(ms) / 1000.0, tz=timezone.utc).astimezone(tz) - except (TypeError, ValueError, OSError): - return None - if dt.hour < reset_hour: - dt = dt - timedelta(days=1) - return dt.strftime("%Y-%m-%d") - - -def today_trading_day(*, reset_hour: int = TRADING_DAY_RESET_HOUR) -> str: - return ms_to_trading_day(_now_ms(), reset_hour=reset_hour) or datetime.now( - CHART_DISPLAY_TZ - ).strftime("%Y-%m-%d") - - -def trading_day_bounds_ms( - trading_day: str, - *, - reset_hour: int = TRADING_DAY_RESET_HOUR, - tz: ZoneInfo = CHART_DISPLAY_TZ, -) -> tuple[int, int]: - day = datetime.strptime((trading_day or "").strip()[:10], "%Y-%m-%d") - start = day.replace(hour=reset_hour, minute=0, second=0, microsecond=0, tzinfo=tz) - end = start + timedelta(days=1) - return int(start.timestamp() * 1000), int(end.timestamp() * 1000) - - -def resolve_period_bounds( - *, - period: str = "", - trading_day: str = "", - date_from: str = "", - date_to: str = "", - reset_hour: int = TRADING_DAY_RESET_HOUR, -) -> tuple[int, int, str, str, str]: - """返回 (start_ms, end_ms, date_from, date_to, period_label)。""" - td = today_trading_day(reset_hour=reset_hour) - p = (period or "today").strip().lower() - if p in ("day", "today", ""): - d = (trading_day or "").strip()[:10] or td - start_ms, end_ms = trading_day_bounds_ms(d, reset_hour=reset_hour) - return start_ms, end_ms, d, d, f"本日 {d}" - if p == "week": - day_dt = datetime.strptime(td, "%Y-%m-%d") - monday = day_dt - timedelta(days=day_dt.weekday()) - df = monday.strftime("%Y-%m-%d") - start_ms, _ = trading_day_bounds_ms(df, reset_hour=reset_hour) - _, end_ms = trading_day_bounds_ms(td, reset_hour=reset_hour) - return start_ms, end_ms, df, td, f"本周 {df}~{td}" - if p == "month": - day_dt = datetime.strptime(td, "%Y-%m-%d") - first = day_dt.replace(day=1) - df = first.strftime("%Y-%m-%d") - start_ms, _ = trading_day_bounds_ms(df, reset_hour=reset_hour) - _, end_ms = trading_day_bounds_ms(td, reset_hour=reset_hour) - return start_ms, end_ms, df, td, f"本月 {df}~{td}" - if p == "range": - df = (date_from or "").strip()[:10] or td - dt = (date_to or "").strip()[:10] or df - if df > dt: - df, dt = dt, df - start_ms, _ = trading_day_bounds_ms(df, reset_hour=reset_hour) - _, end_ms = trading_day_bounds_ms(dt, reset_hour=reset_hour) - label = f"区间 {df}~{dt}" if df != dt else f"区间 {df}" - return start_ms, end_ms, df, dt, label - d = (trading_day or "").strip()[:10] or td - start_ms, end_ms = trading_day_bounds_ms(d, reset_hour=reset_hour) - return start_ms, end_ms, d, d, f"本日 {d}" - - -def _pnl_side(pnl: float) -> str: - if pnl > 0.0001: - return "win" - if pnl < -0.0001: - return "loss" - return "flat" - - -def _empty_pnl_bucket() -> dict[str, Any]: - return { - "open_count": 0, - "sick_count": 0, - "pnl_total": 0.0, - "pnl_ex_sick": 0.0, - "turnover_total": 0.0, - "commission_total": 0.0, - "win_count": 0, - "loss_count": 0, - "avg_win": None, - "avg_loss": None, - "max_win": None, - "max_loss": None, - } - - -def _finalize_pnl_bucket(bucket: dict[str, Any]) -> None: - wins = bucket.pop("_wins", []) - losses = bucket.pop("_losses", []) - open_count = int(bucket.get("open_count") or 0) - win_count = len(wins) - bucket["win_count"] = win_count - bucket["loss_count"] = len(losses) - bucket["avg_win"] = round(sum(wins) / len(wins), 4) if wins else None - avg_loss = round(sum(losses) / len(losses), 4) if losses else None - bucket["avg_loss"] = avg_loss - bucket["max_win"] = round(max(wins), 4) if wins else None - bucket["max_loss"] = round(min(losses), 4) if losses else None - bucket["pnl_total"] = round(float(bucket.get("pnl_total") or 0), 4) - bucket["pnl_ex_sick"] = round(float(bucket.get("pnl_ex_sick") or 0), 4) - bucket["turnover_total"] = round(float(bucket.get("turnover_total") or 0), 4) - bucket["commission_total"] = round(float(bucket.get("commission_total") or 0), 4) - bucket["win_rate"] = round(win_count / open_count * 100, 1) if open_count else None - avg_win = bucket["avg_win"] - if avg_win is not None and avg_loss is not None and avg_loss != 0: - bucket["profit_loss_ratio"] = round(avg_win / abs(avg_loss), 2) - else: - bucket["profit_loss_ratio"] = None - - -def _accumulate_trade_stat( - bucket: dict[str, Any], - *, - pnl: float, - is_sick: bool, - turnover: float = 0.0, - commission: float = 0.0, -) -> None: - bucket["open_count"] += 1 - bucket["pnl_total"] += pnl - bucket["turnover_total"] += turnover - bucket["commission_total"] += commission - if is_sick: - bucket["sick_count"] += 1 - else: - bucket["pnl_ex_sick"] += pnl - side = _pnl_side(pnl) - if side == "win": - bucket.setdefault("_wins", []).append(pnl) - elif side == "loss": - bucket.setdefault("_losses", []).append(pnl) - - -def _compute_period_stats(trade_rows: list[dict[str, Any]]) -> dict[str, Any]: - total_bucket = _empty_pnl_bucket() - by_ex: dict[str, dict[str, Any]] = {} - for td_row in trade_rows: - ex = str(td_row.get("exchange_key") or "?") - pnl = float(td_row.get("pnl_amount") or 0) - tag = str(td_row.get("behavior_tag") or "") - is_sick = tag == "sick" - turnover = float(td_row.get("exchange_turnover_usdt") or 0) - commission = float(td_row.get("exchange_commission_usdt") or 0) - _accumulate_trade_stat( - total_bucket, pnl=pnl, is_sick=is_sick, turnover=turnover, commission=commission - ) - if ex not in by_ex: - by_ex[ex] = _empty_pnl_bucket() - _accumulate_trade_stat( - by_ex[ex], pnl=pnl, is_sick=is_sick, turnover=turnover, commission=commission - ) - _finalize_pnl_bucket(total_bucket) - for ex in by_ex: - _finalize_pnl_bucket(by_ex[ex]) - total = int(total_bucket["open_count"] or 0) - sick = int(total_bucket["sick_count"] or 0) - sick_pct = round(sick / total * 100, 1) if total else 0.0 - return { - "open_count": total, - "sick_count": sick, - "sick_pct": sick_pct, - "pnl_total": total_bucket["pnl_total"], - "pnl_ex_sick": total_bucket["pnl_ex_sick"], - "win_count": total_bucket["win_count"], - "loss_count": total_bucket["loss_count"], - "avg_win": total_bucket["avg_win"], - "avg_loss": total_bucket["avg_loss"], - "max_win": total_bucket["max_win"], - "max_loss": total_bucket["max_loss"], - "win_rate": total_bucket["win_rate"], - "profit_loss_ratio": total_bucket["profit_loss_ratio"], - "turnover_total": total_bucket["turnover_total"], - "commission_total": total_bucket["commission_total"], - "by_exchange": by_ex, - } - - -def list_review_quotes(*, db_path: Path | None = None) -> list[dict[str, Any]]: - init_db(db_path) - conn = _connect(db_path) - try: - rows = conn.execute( - """ - SELECT id, quote_date, content, created_at, updated_at - FROM archive_review_quotes - ORDER BY quote_date DESC - LIMIT ? - """, - (ARCHIVE_QUOTES_MAX,), - ).fetchall() - return [dict(r) for r in rows] - finally: - conn.close() - - -def create_review_quote( - quote_date: str, - content: str, - *, - db_path: Path | None = None, -) -> dict[str, Any]: - init_db(db_path) - qd = (quote_date or "").strip()[:10] - if not qd: - raise ValueError("缺少 quote_date") - text = (content or "").strip() - if not text: - raise ValueError("语录内容不能为空") - if len(text) > ARCHIVE_QUOTE_MAX_LEN: - raise ValueError(f"语录最长 {ARCHIVE_QUOTE_MAX_LEN} 字") - conn = _connect(db_path) - try: - cnt = conn.execute("SELECT COUNT(*) AS c FROM archive_review_quotes").fetchone() - if int(cnt["c"] or 0) >= ARCHIVE_QUOTES_MAX: - raise ValueError(f"复盘语录最多保存 {ARCHIVE_QUOTES_MAX} 条") - now = _now_ms() - try: - cur = conn.execute( - """ - INSERT INTO archive_review_quotes (quote_date, content, created_at, updated_at) - VALUES (?,?,?,?) - """, - (qd, text, now, now), - ) - except sqlite3.IntegrityError as e: - raise ValueError("该日期已有语录,请展开编辑") from e - rid = int(cur.lastrowid) - row = conn.execute( - "SELECT id, quote_date, content, created_at, updated_at FROM archive_review_quotes WHERE id=?", - (rid,), - ).fetchone() - return dict(row) - finally: - conn.close() - - -def update_review_quote( - quote_id: int, - *, - quote_date: str | None = None, - content: str | None = None, - db_path: Path | None = None, -) -> dict[str, Any] | None: - init_db(db_path) - conn = _connect(db_path) - try: - row = conn.execute( - "SELECT id, quote_date, content FROM archive_review_quotes WHERE id=?", - (int(quote_id),), - ).fetchone() - if not row: - return None - qd = (quote_date or row["quote_date"] or "").strip()[:10] - text = (content if content is not None else row["content"] or "").strip() - if not qd or not text: - raise ValueError("日期与内容均不能为空") - if len(text) > ARCHIVE_QUOTE_MAX_LEN: - raise ValueError(f"语录最长 {ARCHIVE_QUOTE_MAX_LEN} 字") - now = _now_ms() - conn.execute( - """ - UPDATE archive_review_quotes - SET quote_date=?, content=?, updated_at=? - WHERE id=? - """, - (qd, text, now, int(quote_id)), - ) - out = conn.execute( - "SELECT id, quote_date, content, created_at, updated_at FROM archive_review_quotes WHERE id=?", - (int(quote_id),), - ).fetchone() - return dict(out) if out else None - finally: - conn.close() - - -def delete_review_quote(quote_id: int, *, db_path: Path | None = None) -> bool: - init_db(db_path) - conn = _connect(db_path) - try: - cur = conn.execute( - "DELETE FROM archive_review_quotes WHERE id=?", - (int(quote_id),), - ) - return int(cur.rowcount or 0) > 0 - finally: - conn.close() - - -def list_daily_trades( - trading_day: str = "", - *, - period: str = "", - date_from: str = "", - date_to: str = "", - exchange_key: str = "", - filter_profit: bool = False, - filter_loss: bool = False, - filter_sick: bool = False, - search: str = "", - db_path: Path | None = None, -) -> dict[str, Any]: - """按日期区间列出平仓记录(本日/本周/本月/自选,以平仓时间计),含犯病与盈亏统计。""" - init_db(db_path) - p = (period or "today").strip().lower() or "today" - start_ms, end_ms, df, dt, period_label = resolve_period_bounds( - period=p, - trading_day=trading_day, - date_from=date_from, - date_to=date_to, - ) - ex_filter = (exchange_key or "").strip().lower() - conn = _connect(db_path) - try: - params: list[Any] = [start_ms, end_ms] - where = "closed_at_ms IS NOT NULL AND closed_at_ms >= ? AND closed_at_ms < ?" - if ex_filter: - where += " AND exchange_key=?" - params.append(ex_filter) - rows = conn.execute( - f""" - SELECT * FROM archive_trade_cache - WHERE {where} - ORDER BY closed_at_ms DESC, trade_id DESC - """, - params, - ).fetchall() - overlays_by_ex: dict[str, dict[int, dict]] = {} - trades: list[dict[str, Any]] = [] - q = (search or "").strip().lower() - for r in rows: - ex_k = r["exchange_key"] - if ex_k not in overlays_by_ex: - overlays_by_ex[ex_k] = load_overlays(ex_k, db_path=db_path) - td_row = _trade_row_to_dict(r, overlays_by_ex[ex_k].get(int(r["trade_id"]))) - pnl = float(td_row.get("pnl_amount") or 0) - tag = td_row.get("behavior_tag") or "" - if filter_profit and pnl <= 0.0001: - continue - if filter_loss and pnl >= -0.0001: - continue - if filter_sick and tag != "sick": - continue - if q: - blob = " ".join( - str(td_row.get(k) or "") - for k in ( - "symbol", - "exchange_key", - "direction", - "result", - "note", - "monitor_type", - "entry_reason", - ) - ).lower() - if q not in blob: - continue - trades.append(td_row) - return { - "period": p, - "period_label": period_label, - "trading_day": dt, - "date_from": df, - "date_to": dt, - "trades": trades, - "stats": _compute_period_stats(trades), - } - finally: - conn.close() - - -def list_archive_calendar( - year: int, - month: int, - *, - exchange_key: str = "", - db_path: Path | None = None, - reset_hour: int = TRADING_DAY_RESET_HOUR, -) -> dict[str, Any]: - """按月返回每个交易日的盈亏、笔数、犯病标记(08:00 切日)。""" - init_db(db_path) - y = int(year) - m = int(month) - if m < 1 or m > 12: - raise ValueError("month 无效") - first = f"{y:04d}-{m:02d}-01" - if m == 12: - next_first = datetime(y + 1, 1, 1) - else: - next_first = datetime(y, m + 1, 1) - last = (next_first - timedelta(days=1)).strftime("%Y-%m-%d") - start_ms, _ = trading_day_bounds_ms(first, reset_hour=reset_hour) - _, end_ms = trading_day_bounds_ms(last, reset_hour=reset_hour) - ex_filter = (exchange_key or "").strip().lower() - conn = _connect(db_path) - try: - params: list[Any] = [start_ms, end_ms] - where = "closed_at_ms IS NOT NULL AND closed_at_ms >= ? AND closed_at_ms < ?" - if ex_filter: - where += " AND exchange_key=?" - params.append(ex_filter) - rows = conn.execute( - f"SELECT * FROM archive_trade_cache WHERE {where}", - params, - ).fetchall() - overlays_by_ex: dict[str, dict[int, dict]] = {} - days: dict[str, dict[str, Any]] = {} - for r in rows: - ex_k = r["exchange_key"] - if ex_k not in overlays_by_ex: - overlays_by_ex[ex_k] = load_overlays(ex_k, db_path=db_path) - td_row = _trade_row_to_dict(r, overlays_by_ex[ex_k].get(int(r["trade_id"]))) - closed_ms = td_row.get("closed_at_ms") or _parse_dt_ms(td_row.get("closed_at")) - if not closed_ms: - continue - day = ms_to_trading_day(int(closed_ms), reset_hour=reset_hour) - if not day: - continue - if day < first or day > last: - continue - bucket = days.setdefault( - day, - { - "trading_day": day, - "open_count": 0, - "sick_count": 0, - "pnl_total": 0.0, - "turnover_total": 0.0, - "commission_total": 0.0, - "has_sick": False, - }, - ) - pnl = float(td_row.get("pnl_amount") or 0) - tag = str(td_row.get("behavior_tag") or "") - is_sick = tag == "sick" - bucket["open_count"] += 1 - bucket["pnl_total"] += pnl - bucket["turnover_total"] += float(td_row.get("exchange_turnover_usdt") or 0) - bucket["commission_total"] += float(td_row.get("exchange_commission_usdt") or 0) - if is_sick: - bucket["sick_count"] += 1 - bucket["has_sick"] = True - for d in days.values(): - d["pnl_total"] = round(float(d["pnl_total"]), 4) - d["turnover_total"] = round(float(d["turnover_total"]), 4) - d["commission_total"] = round(float(d["commission_total"]), 4) - month_pnl = sum(float(d["pnl_total"]) for d in days.values()) - month_count = sum(int(d["open_count"]) for d in days.values()) - return { - "year": y, - "month": m, - "date_from": first, - "date_to": last, - "days": days, - "month_pnl_total": round(month_pnl, 4), - "month_open_count": month_count, - } - finally: - conn.close() +"""中控币种档案:永久 5m K 线库(建档种子 + 4h 增量),交易缓存与 overlay。""" + +from __future__ import annotations + +import json +import os +import sqlite3 +import time +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any, Callable, Optional +from zoneinfo import ZoneInfo + +CHART_DISPLAY_TZ = ZoneInfo(os.getenv("APP_TIMEZONE", "Asia/Shanghai")) + +from lib.hub.hub_ohlcv_lib import ( + TIMEFRAME_MS, + aggregate_ohlcv_bars, + normalize_chart_timeframe, + normalize_perpetual_symbol, +) +from lib.hub.hub_trades_lib import ( + display_entry_type_label, + effective_hold_minutes, + format_hold_minutes, +) + +ARCHIVE_TIMEFRAMES = frozenset({"5m", "15m", "1h", "4h"}) +ARCHIVE_DEFAULT_TIMEFRAME = "15m" +ARCHIVE_SEED_LOOKBACK_DAYS = 30 +ARCHIVE_VISIBLE_BARS_DEFAULT = 200 +ARCHIVE_MAX_CANDLES: dict[str, int] = { + "5m": 9000, + "15m": 15000, + "1h": 4000, + "4h": 2000, +} +ARCHIVE_SYNC_INTERVAL_SEC = int(os.getenv("HUB_ARCHIVE_SYNC_INTERVAL_SEC", str(4 * 3600))) +ARCHIVE_TRADE_DAYS = int(os.getenv("HUB_ARCHIVE_TRADE_DAYS", "365")) +ARCHIVE_TRADE_LIMIT = int(os.getenv("HUB_ARCHIVE_TRADE_LIMIT", "2000")) +ARCHIVE_QUOTES_MAX = int(os.getenv("HUB_ARCHIVE_QUOTES_MAX", "100")) +TRADING_DAY_RESET_HOUR = int(os.getenv("TRADING_DAY_RESET_HOUR", "8")) +ARCHIVE_QUOTE_MAX_LEN = 5000 + +BEHAVIOR_TAGS = frozenset({"", "sick", "emotion"}) + + +def default_db_path() -> Path: + raw = (os.getenv("HUB_ARCHIVE_DB_PATH") or "").strip() + if raw: + return Path(raw) + hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data" + hub_dir.mkdir(parents=True, exist_ok=True) + return hub_dir / "hub_symbol_archive.db" + + +def _connect(db_path: Path | None = None) -> sqlite3.Connection: + path = db_path or default_db_path() + path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(path), timeout=30, isolation_level=None) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + return conn + + +def init_db(db_path: Path | None = None) -> None: + conn = _connect(db_path) + try: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS archive_meta ( + exchange_key TEXT NOT NULL, + symbol TEXT NOT NULL, + first_trade_opened_ms INTEGER, + archive_started_at INTEGER NOT NULL, + last_kline_sync_ms INTEGER, + last_trade_sync_ms INTEGER, + seed_complete INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (exchange_key, symbol) + ) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS archive_bars_5m ( + exchange_key TEXT NOT NULL, + symbol TEXT NOT NULL, + open_time_ms INTEGER NOT NULL, + open REAL NOT NULL, + high REAL NOT NULL, + low REAL NOT NULL, + close REAL NOT NULL, + volume REAL NOT NULL DEFAULT 0, + updated_at INTEGER NOT NULL, + PRIMARY KEY (exchange_key, symbol, open_time_ms) + ) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_archive_bars_series + ON archive_bars_5m (exchange_key, symbol, open_time_ms) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS archive_trade_cache ( + exchange_key TEXT NOT NULL, + trade_id INTEGER NOT NULL, + symbol TEXT NOT NULL, + direction TEXT, + result TEXT, + pnl_amount REAL, + opened_at TEXT, + closed_at TEXT, + opened_at_ms INTEGER, + closed_at_ms INTEGER, + monitor_type TEXT, + entry_reason TEXT, + exchange_turnover_usdt REAL, + exchange_commission_usdt REAL, + payload_json TEXT, + synced_at INTEGER NOT NULL, + PRIMARY KEY (exchange_key, trade_id) + ) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_archive_trades_sym + ON archive_trade_cache (exchange_key, symbol, closed_at_ms) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS trade_overlay ( + exchange_key TEXT NOT NULL, + trade_id INTEGER NOT NULL, + behavior_tag TEXT NOT NULL DEFAULT '', + note TEXT NOT NULL DEFAULT '', + updated_at INTEGER NOT NULL, + PRIMARY KEY (exchange_key, trade_id) + ) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS archive_review_quotes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + quote_date TEXT NOT NULL UNIQUE, + content TEXT NOT NULL DEFAULT '', + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_archive_quotes_date + ON archive_review_quotes (quote_date DESC) + """ + ) + for ddl in ( + "ALTER TABLE archive_trade_cache ADD COLUMN exchange_turnover_usdt REAL", + "ALTER TABLE archive_trade_cache ADD COLUMN exchange_commission_usdt REAL", + ): + try: + conn.execute(ddl) + except Exception: + pass + finally: + conn.close() + + +def _now_ms() -> int: + return int(time.time() * 1000) + + +def _optional_float(raw: Any) -> float | None: + if raw in (None, ""): + return None + try: + return float(raw) + except (TypeError, ValueError): + return None + + +def parse_wall_clock_ms(raw: Any, *, tz: ZoneInfo = CHART_DISPLAY_TZ) -> int | None: + """将 YYYY-MM-DD[ HH:MM[:SS]] 按指定时区墙钟解析为 UTC 毫秒(默认 UTC+8)。""" + if raw in (None, ""): + return None + try: + if isinstance(raw, (int, float)): + v = int(raw) + return v if v > 1_000_000_000_000 else v * 1000 + except (TypeError, ValueError): + pass + s = str(raw).strip().replace("Z", "").replace("T", " ") + if not s: + return None + if s.isdigit(): + v = int(s) + return v if v > 1_000_000_000_000 else v * 1000 + for fmt, ln in (("%Y-%m-%d %H:%M:%S", 19), ("%Y-%m-%d %H:%M", 16), ("%Y-%m-%d", 10)): + try: + dt = datetime.strptime(s[:ln], fmt) + aware = dt.replace(tzinfo=tz) + return int(aware.timestamp() * 1000) + except ValueError: + continue + return None + + +def ms_to_wall_clock_str(ms: int, *, tz: ZoneInfo = CHART_DISPLAY_TZ) -> str: + dt = datetime.fromtimestamp(int(ms) / 1000.0, tz=timezone.utc).astimezone(tz) + return dt.strftime("%Y-%m-%d %H:%M:%S") + + +def _parse_dt_ms(raw: Any) -> int | None: + return parse_wall_clock_ms(raw) + + +def _trade_entry_reason_for_cache(t: dict[str, Any]) -> str: + for key in ("entry_type", "entry_reason", "reviewed_entry_reason"): + raw = t.get(key) + if raw is not None and str(raw).strip(): + return str(raw).strip() + return display_entry_type_label(t) if isinstance(t, dict) else "" + + +def purge_stale_trades_cache( + exchange_key: str, + active_trade_ids: list[int] | set[int], + *, + db_path: Path | None = None, +) -> int: + """删除该所缓存中已不在复盘/交易记录里的条目。""" + ex_k = (exchange_key or "").strip().lower() + if not ex_k: + return 0 + ids: list[int] = [] + for raw in active_trade_ids or []: + try: + ids.append(int(raw)) + except (TypeError, ValueError): + continue + conn = _connect(db_path) + try: + if not ids: + rows = conn.execute( + "SELECT trade_id FROM archive_trade_cache WHERE exchange_key=?", + (ex_k,), + ).fetchall() + stale_ids = [int(r["trade_id"]) for r in rows] + cur = conn.execute( + "DELETE FROM archive_trade_cache WHERE exchange_key=?", + (ex_k,), + ) + else: + placeholders = ",".join("?" * len(ids)) + rows = conn.execute( + f""" + SELECT trade_id FROM archive_trade_cache + WHERE exchange_key=? AND trade_id NOT IN ({placeholders}) + """, + (ex_k, *ids), + ).fetchall() + stale_ids = [int(r["trade_id"]) for r in rows] + cur = conn.execute( + f""" + DELETE FROM archive_trade_cache + WHERE exchange_key=? AND trade_id NOT IN ({placeholders}) + """, + (ex_k, *ids), + ) + removed = int(cur.rowcount or 0) + if stale_ids: + ph2 = ",".join("?" * len(stale_ids)) + conn.execute( + f""" + DELETE FROM trade_overlay + WHERE exchange_key=? AND trade_id IN ({ph2}) + """, + (ex_k, *stale_ids), + ) + return removed + finally: + conn.close() + + +def delete_trade_from_archive( + exchange_key: str, + trade_id: int, + *, + db_path: Path | None = None, +) -> bool: + ex_k = (exchange_key or "").strip().lower() + tid = int(trade_id) + conn = _connect(db_path) + try: + cur = conn.execute( + """ + DELETE FROM archive_trade_cache + WHERE exchange_key=? AND trade_id=? + """, + (ex_k, tid), + ) + conn.execute( + "DELETE FROM trade_overlay WHERE exchange_key=? AND trade_id=?", + (ex_k, tid), + ) + return int(cur.rowcount or 0) > 0 + finally: + conn.close() + + +def upsert_trades_cache( + exchange_key: str, + trades: list[dict[str, Any]], + *, + db_path: Path | None = None, + prune_missing: bool = True, +) -> dict[str, int]: + init_db(db_path) + ex_k = (exchange_key or "").strip().lower() + if not ex_k: + return {"upserted": 0, "removed": 0} + now = _now_ms() + n = 0 + active_ids: list[int] = [] + conn = _connect(db_path) + try: + for t in trades or []: + try: + tid = int(t.get("id")) + except (TypeError, ValueError): + continue + sym = (t.get("symbol") or "").strip().upper() + if not sym: + continue + active_ids.append(tid) + row = dict(t) + row["exchange_key"] = ex_k + row.pop("account_exchange_key", None) + payload = {k: row.get(k) for k in row.keys()} + entry_label = _trade_entry_reason_for_cache(t) + conn.execute( + """ + INSERT INTO archive_trade_cache ( + exchange_key, trade_id, symbol, direction, result, pnl_amount, + opened_at, closed_at, opened_at_ms, closed_at_ms, + monitor_type, entry_reason, exchange_turnover_usdt, exchange_commission_usdt, + payload_json, synced_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) + ON CONFLICT(exchange_key, trade_id) DO UPDATE SET + symbol=excluded.symbol, + direction=excluded.direction, + result=excluded.result, + pnl_amount=excluded.pnl_amount, + opened_at=excluded.opened_at, + closed_at=excluded.closed_at, + opened_at_ms=excluded.opened_at_ms, + closed_at_ms=excluded.closed_at_ms, + monitor_type=excluded.monitor_type, + entry_reason=excluded.entry_reason, + exchange_turnover_usdt=excluded.exchange_turnover_usdt, + exchange_commission_usdt=excluded.exchange_commission_usdt, + payload_json=excluded.payload_json, + synced_at=excluded.synced_at + """, + ( + ex_k, + tid, + sym, + t.get("direction"), + t.get("result"), + float(t.get("pnl_amount") or 0), + t.get("opened_at"), + t.get("closed_at"), + t.get("opened_at_ms") or _parse_dt_ms(t.get("opened_at")), + t.get("closed_at_ms") or _parse_dt_ms(t.get("closed_at")), + t.get("monitor_type"), + entry_label, + _optional_float(t.get("exchange_turnover_usdt")), + _optional_float(t.get("exchange_commission_usdt")), + json.dumps(payload, ensure_ascii=False, default=str), + now, + ), + ) + n += 1 + finally: + conn.close() + removed = 0 + if prune_missing: + removed = purge_stale_trades_cache(ex_k, active_ids, db_path=db_path) + return {"upserted": n, "removed": removed} + + +def _enrich_trade_display_fields(out: dict[str, Any]) -> dict[str, Any]: + """缓存行补齐复盘优先的展示字段(兼容旧同步数据)。""" + opened_ms = out.get("opened_at_ms") or _parse_dt_ms(out.get("opened_at")) + closed_ms = out.get("closed_at_ms") or _parse_dt_ms(out.get("closed_at")) + if opened_ms: + out["opened_at_ms"] = int(opened_ms) + if closed_ms: + out["closed_at_ms"] = int(closed_ms) + if not out.get("opened_at") and opened_ms: + out["opened_at"] = ms_to_wall_clock_str(int(opened_ms)) + if not out.get("closed_at") and closed_ms: + out["closed_at"] = ms_to_wall_clock_str(int(closed_ms)) + entry_type = display_entry_type_label(out) + if entry_type and entry_type != "—": + out["entry_type"] = entry_type + out["entry_reason"] = entry_type + hold_m = out.get("hold_minutes") + if hold_m in (None, ""): + hold_m = effective_hold_minutes( + out, + opened_ms=out.get("opened_at_ms"), + closed_ms=out.get("closed_at_ms"), + ) + try: + hold_m = max(0, int(hold_m or 0)) + except (TypeError, ValueError): + hold_m = 0 + out["hold_minutes"] = hold_m + out["hold_minutes_text"] = out.get("hold_minutes_text") or format_hold_minutes(hold_m) + if "reviewed" not in out: + out["reviewed"] = bool( + out.get("reviewed_at") + or out.get("reviewed_result") + or out.get("reviewed_opened_at") + or out.get("reviewed_closed_at") + or out.get("reviewed_entry_reason") + or out.get("reviewed_hold_minutes") + ) + return out + + +def _trade_row_to_dict(row: sqlite3.Row, overlay: dict | None = None) -> dict[str, Any]: + d = dict(row) + payload = {} + raw = d.pop("payload_json", None) + if raw: + try: + payload = json.loads(raw) + except (json.JSONDecodeError, TypeError): + payload = {} + out = {**payload, **{k: d[k] for k in d.keys() if k not in payload}} + for key in ( + "exchange_key", + "symbol", + "trade_id", + "direction", + "result", + "pnl_amount", + "opened_at", + "closed_at", + "opened_at_ms", + "closed_at_ms", + "monitor_type", + "entry_reason", + "exchange_turnover_usdt", + "exchange_commission_usdt", + "synced_at", + ): + if key in d and d[key] not in (None, ""): + out[key] = d[key] + ov = overlay or {} + out["behavior_tag"] = ov.get("behavior_tag") or "" + out["note"] = ov.get("note") or "" + out["trade_id"] = out.get("trade_id") or out.get("id") + ex_col = str(d.get("exchange_key") or "").strip().lower() + if ex_col: + out["exchange_key"] = ex_col + out.pop("account_exchange_key", None) + return _enrich_trade_display_fields(out) + + +def load_overlays( + exchange_key: str, + trade_ids: list[int] | None = None, + *, + db_path: Path | None = None, +) -> dict[int, dict[str, Any]]: + ex_k = (exchange_key or "").strip().lower() + conn = _connect(db_path) + try: + if trade_ids: + placeholders = ",".join("?" * len(trade_ids)) + rows = conn.execute( + f""" + SELECT exchange_key, trade_id, behavior_tag, note, updated_at + FROM trade_overlay + WHERE exchange_key=? AND trade_id IN ({placeholders}) + """, + (ex_k, *trade_ids), + ).fetchall() + else: + rows = conn.execute( + """ + SELECT exchange_key, trade_id, behavior_tag, note, updated_at + FROM trade_overlay WHERE exchange_key=? + """, + (ex_k,), + ).fetchall() + return { + int(r["trade_id"]): { + "behavior_tag": r["behavior_tag"] or "", + "note": r["note"] or "", + "updated_at": r["updated_at"], + } + for r in rows + } + finally: + conn.close() + + +def upsert_trade_overlay( + exchange_key: str, + trade_id: int, + *, + behavior_tag: str | None = None, + note: str | None = None, + db_path: Path | None = None, +) -> dict[str, Any]: + init_db(db_path) + ex_k = (exchange_key or "").strip().lower() + tid = int(trade_id) + tag = (behavior_tag or "").strip().lower() + if tag not in BEHAVIOR_TAGS: + tag = "" + note_text = (note or "").strip()[:2000] + now = _now_ms() + conn = _connect(db_path) + try: + conn.execute( + """ + INSERT INTO trade_overlay (exchange_key, trade_id, behavior_tag, note, updated_at) + VALUES (?,?,?,?,?) + ON CONFLICT(exchange_key, trade_id) DO UPDATE SET + behavior_tag=excluded.behavior_tag, + note=excluded.note, + updated_at=excluded.updated_at + """, + (ex_k, tid, tag, note_text, now), + ) + finally: + conn.close() + return {"exchange_key": ex_k, "trade_id": tid, "behavior_tag": tag, "note": note_text} + + +def list_symbol_rows( + *, + exchange_key: str = "", + filter_profit: bool = False, + filter_loss: bool = False, + filter_sick: bool = False, + filter_emotion: bool = False, + db_path: Path | None = None, +) -> list[dict[str, Any]]: + """一所一币一行汇总。""" + init_db(db_path) + conn = _connect(db_path) + try: + params: list[Any] = [] + where = "1=1" + ex_filter = (exchange_key or "").strip().lower() + if ex_filter: + where += " AND t.exchange_key=?" + params.append(ex_filter) + + rows = conn.execute( + f""" + SELECT t.exchange_key, t.symbol, + COUNT(*) AS trade_count, + SUM(CASE WHEN t.pnl_amount > 0.0001 THEN 1 ELSE 0 END) AS win_count, + SUM(CASE WHEN t.pnl_amount < -0.0001 THEN 1 ELSE 0 END) AS loss_count, + SUM(COALESCE(t.pnl_amount, 0)) AS total_pnl, + MIN(COALESCE(t.opened_at_ms, 0)) AS first_opened_ms, + MAX(COALESCE(t.closed_at_ms, 0)) AS last_closed_ms + FROM archive_trade_cache t + WHERE {where} + GROUP BY t.exchange_key, t.symbol + ORDER BY last_closed_ms DESC + """, + params, + ).fetchall() + + overlays_by_ex: dict[str, dict[int, dict]] = {} + out: list[dict[str, Any]] = [] + for r in rows: + ex_k = r["exchange_key"] + sym = r["symbol"] + if ex_k not in overlays_by_ex: + overlays_by_ex[ex_k] = load_overlays(ex_k, db_path=db_path) + + trade_rows = conn.execute( + """ + SELECT trade_id, pnl_amount FROM archive_trade_cache + WHERE exchange_key=? AND symbol=? + """, + (ex_k, sym), + ).fetchall() + has_profit = any(float(x["pnl_amount"] or 0) > 0.0001 for x in trade_rows) + has_loss = any(float(x["pnl_amount"] or 0) < -0.0001 for x in trade_rows) + has_sick = False + has_emotion = False + ov_map = overlays_by_ex.get(ex_k) or {} + for tr in trade_rows: + ov = ov_map.get(int(tr["trade_id"])) or {} + if ov.get("behavior_tag") == "sick": + has_sick = True + if ov.get("behavior_tag") == "emotion": + has_emotion = True + + if filter_profit and not has_profit: + continue + if filter_loss and not has_loss: + continue + if filter_sick and not has_sick: + continue + if filter_emotion and not has_emotion: + continue + + meta = conn.execute( + "SELECT seed_complete, last_kline_sync_ms FROM archive_meta WHERE exchange_key=? AND symbol=?", + (ex_k, sym), + ).fetchone() + + out.append( + { + "exchange_key": ex_k, + "symbol": sym, + "trade_count": int(r["trade_count"] or 0), + "win_count": int(r["win_count"] or 0), + "loss_count": int(r["loss_count"] or 0), + "total_pnl": round(float(r["total_pnl"] or 0), 4), + "first_opened_ms": int(r["first_opened_ms"] or 0) or None, + "last_closed_ms": int(r["last_closed_ms"] or 0) or None, + "seed_complete": bool(meta["seed_complete"]) if meta else False, + "last_kline_sync_ms": int(meta["last_kline_sync_ms"] or 0) if meta else None, + } + ) + return out + finally: + conn.close() + + +def load_symbol_trades( + exchange_key: str, + symbol: str, + *, + db_path: Path | None = None, +) -> list[dict[str, Any]]: + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + conn = _connect(db_path) + try: + rows = conn.execute( + """ + SELECT * FROM archive_trade_cache + WHERE exchange_key=? AND symbol=? + ORDER BY COALESCE(closed_at_ms, 0) DESC, trade_id DESC + """, + (ex_k, sym), + ).fetchall() + ids = [int(r["trade_id"]) for r in rows] + ov = load_overlays(ex_k, ids, db_path=db_path) + return [_trade_row_to_dict(r, ov.get(int(r["trade_id"]))) for r in rows] + finally: + conn.close() + + +def upsert_bars_5m( + exchange_key: str, + symbol: str, + bars: list[dict[str, Any]], + *, + db_path: Path | None = None, +) -> int: + init_db(db_path) + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + now = _now_ms() + n = 0 + conn = _connect(db_path) + try: + for b in bars or []: + try: + conn.execute( + """ + INSERT INTO archive_bars_5m ( + exchange_key, symbol, open_time_ms, open, high, low, close, volume, updated_at + ) VALUES (?,?,?,?,?,?,?,?,?) + ON CONFLICT(exchange_key, symbol, open_time_ms) DO UPDATE SET + open=excluded.open, + high=excluded.high, + low=excluded.low, + close=excluded.close, + volume=excluded.volume, + updated_at=excluded.updated_at + """, + ( + ex_k, + sym, + int(b["open_time_ms"]), + float(b["open"]), + float(b["high"]), + float(b["low"]), + float(b["close"]), + float(b.get("volume") or 0), + now, + ), + ) + n += 1 + except (KeyError, TypeError, ValueError): + continue + finally: + conn.close() + return n + + +def load_bars_5m_range( + exchange_key: str, + symbol: str, + start_ms: int, + end_ms: int, + *, + db_path: Path | None = None, +) -> list[dict[str, Any]]: + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + conn = _connect(db_path) + try: + rows = conn.execute( + """ + SELECT open_time_ms, open, high, low, close, volume + FROM archive_bars_5m + WHERE exchange_key=? AND symbol=? + AND open_time_ms >= ? AND open_time_ms <= ? + ORDER BY open_time_ms ASC + """, + (ex_k, sym, int(start_ms), int(end_ms)), + ).fetchall() + return [ + { + "open_time_ms": int(r["open_time_ms"]), + "open": float(r["open"]), + "high": float(r["high"]), + "low": float(r["low"]), + "close": float(r["close"]), + "volume": float(r["volume"] or 0), + } + for r in rows + ] + finally: + conn.close() + + +def _to_candles(bars: list[dict[str, Any]]) -> list[dict[str, Any]]: + out = [] + for b in bars or []: + try: + out.append( + { + "time": int(b["open_time_ms"] // 1000), + "open": float(b["open"]), + "high": float(b["high"]), + "low": float(b["low"]), + "close": float(b["close"]), + "volume": float(b.get("volume") or 0), + } + ) + except (KeyError, TypeError, ValueError): + continue + return out + + +def _snap_to_bar_grid(ts_ms: int, origin_ms: int, step_ms: int) -> int: + step = max(1, int(step_ms)) + origin = int(origin_ms) + if ts_ms <= origin: + return origin + idx = (int(ts_ms) - origin + step - 1) // step + return origin + idx * step + + +def _fill_missing_bars( + bars: list[dict[str, Any]], + period_ms: int, + start_ms: int, + end_ms: int, +) -> list[dict[str, Any]]: + """5m 缺口用上一根收盘价填平,保证聚合后 K 线时间轴连续。""" + by_ts: dict[int, dict[str, Any]] = {} + for b in bars or []: + try: + by_ts[int(b["open_time_ms"])] = b + except (KeyError, TypeError, ValueError): + continue + if not by_ts: + return [] + keys = sorted(by_ts.keys()) + step_ms = max(1, int(period_ms)) + origin = keys[0] + aligned_start = _snap_to_bar_grid(int(start_ms), origin, step_ms) + aligned_end = max(int(end_ms), keys[-1]) + out: list[dict[str, Any]] = [] + last: dict[str, Any] | None = None + for ts_key in keys: + if ts_key <= aligned_start: + last = by_ts[ts_key] + ts = aligned_start + while ts <= aligned_end: + cur = by_ts.get(ts) + if cur is not None: + last = cur + out.append(cur) + elif last is not None: + c = float(last["close"]) + out.append( + { + "open_time_ms": ts, + "open": c, + "high": c, + "low": c, + "close": c, + "volume": 0.0, + "filled": True, + } + ) + ts += step_ms + return out + + +def _archive_earliest_bar_ms( + exchange_key: str, + symbol: str, + *, + db_path: Path | None = None, +) -> int | None: + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + conn = _connect(db_path) + try: + row = conn.execute( + "SELECT MIN(open_time_ms) AS mn FROM archive_bars_5m WHERE exchange_key=? AND symbol=?", + (ex_k, sym), + ).fetchone() + if row and row["mn"] is not None: + return int(row["mn"]) + finally: + conn.close() + return None + + +def _trim_bars_for_cap( + bars: list[dict[str, Any]], + *, + end_ms: int, + max_n: int, +) -> list[dict[str, Any]]: + """超长时优先保留到平仓,再从最古老端截断。""" + if len(bars) <= max_n: + return bars + cut_end = len(bars) + for i in range(len(bars) - 1, -1, -1): + if int(bars[i]["open_time_ms"]) <= int(end_ms): + cut_end = i + 1 + break + essential = bars[:cut_end] + if len(essential) <= max_n: + return essential + return essential[len(essential) - max_n :] + + +def resolve_archive_chart( + exchange_key: str, + symbol: str, + timeframe: str = ARCHIVE_DEFAULT_TIMEFRAME, + *, + anchor_ms: int | None = None, + opened_ms: int | None = None, + closed_ms: int | None = None, + mode: str = "hold", + bars: int = ARCHIVE_VISIBLE_BARS_DEFAULT, + range_mode: str = "window", + db_path: Path | None = None, +) -> dict[str, Any]: + """从永久 5m 库聚合出档案 K 线视窗。 + + range_mode=history:建档起点 → 平仓(不含「到现在」),供拖动/缩放查看建仓前全局形态。 + """ + tf = normalize_chart_timeframe(timeframe, default=ARCHIVE_DEFAULT_TIMEFRAME) + if tf not in ARCHIVE_TIMEFRAMES: + return {"ok": False, "msg": f"档案仅支持 {', '.join(sorted(ARCHIVE_TIMEFRAMES))}"} + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + if not ex_k or not sym: + return {"ok": False, "msg": "缺少 exchange_key 或 symbol"} + + period = TIMEFRAME_MS[tf] + period_5m = TIMEFRAME_MS["5m"] + hold_open = int(opened_ms) if opened_ms else None + hold_close = int(closed_ms) if closed_ms else None + rm = (range_mode or "window").strip().lower() + if hold_open and hold_close and hold_close >= hold_open and rm == "history": + seed_back = max(0, hold_open - ARCHIVE_SEED_LOOKBACK_DAYS * 86400000) + earliest = _archive_earliest_bar_ms(ex_k, sym, db_path=db_path) + if earliest is not None: + start_ms = min(earliest, seed_back) + else: + start_ms = seed_back + end_ms = hold_close + max(period * 16, period_5m * 8) + anchor = hold_close if (mode or "hold").strip().lower() != "entry" else hold_open + elif hold_open and hold_close and hold_close >= hold_open: + hold_len = hold_close - hold_open + pad = max(period * 24, hold_len // 3, period_5m * 12) + start_ms = max(0, hold_open - pad) + end_ms = hold_close + pad + anchor = hold_close if (mode or "hold").strip().lower() != "entry" else hold_open + else: + visible = max(50, min(int(bars or ARCHIVE_VISIBLE_BARS_DEFAULT), 500)) + anchor = int(anchor_ms) if anchor_ms else _now_ms() + half = visible // 2 + start_ms = max(0, anchor - half * period) + end_ms = anchor + half * period + + raw_5m = load_bars_5m_range( + ex_k, + sym, + start_ms - period_5m * 6, + end_ms + period_5m * 6, + db_path=db_path, + ) + if not raw_5m: + return {"ok": False, "msg": "档案库暂无 K 线,请等待同步或手动刷新"} + + filled_5m = _fill_missing_bars(raw_5m, period_5m, start_ms - period_5m * 2, end_ms + period_5m * 2) + + if tf == "5m": + merged = [b for b in filled_5m if start_ms <= int(b["open_time_ms"]) <= end_ms] + else: + agg = aggregate_ohlcv_bars(filled_5m, tf) + merged = [b for b in agg if start_ms <= int(b["open_time_ms"]) <= end_ms] + + max_n = ARCHIVE_MAX_CANDLES.get(tf, 2000) + if rm == "history" and merged and len(merged) > max_n: + merged = merged[:max_n] + + candles = _to_candles(merged) + if not candles: + return {"ok": False, "msg": "视窗内无 K 线"} + + ex_sym = normalize_perpetual_symbol(sym) + return { + "ok": True, + "exchange_key": ex_k, + "symbol": sym, + "exchange_symbol": ex_sym, + "market_type": "swap", + "timeframe": tf, + "mode": (mode or "hold").strip().lower(), + "range_mode": rm, + "anchor_ms": anchor, + "opened_ms": hold_open, + "closed_ms": hold_close, + "window_start_ms": start_ms, + "window_end_ms": end_ms, + "candles": candles, + "bar_count": len(candles), + "gaps_filled": sum(1 for b in filled_5m if b.get("filled")), + } + + +def _ensure_meta( + exchange_key: str, + symbol: str, + first_opened_ms: int | None, + *, + db_path: Path | None = None, +) -> None: + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + now = _now_ms() + conn = _connect(db_path) + try: + row = conn.execute( + "SELECT first_trade_opened_ms FROM archive_meta WHERE exchange_key=? AND symbol=?", + (ex_k, sym), + ).fetchone() + if row: + if first_opened_ms and ( + not row["first_trade_opened_ms"] + or int(first_opened_ms) < int(row["first_trade_opened_ms"]) + ): + conn.execute( + """ + UPDATE archive_meta SET first_trade_opened_ms=? + WHERE exchange_key=? AND symbol=? + """, + (int(first_opened_ms), ex_k, sym), + ) + return + conn.execute( + """ + INSERT INTO archive_meta ( + exchange_key, symbol, first_trade_opened_ms, + archive_started_at, last_kline_sync_ms, last_trade_sync_ms, seed_complete + ) VALUES (?,?,?,?,?,?,0) + """, + (ex_k, sym, int(first_opened_ms) if first_opened_ms else None, now, None, None), + ) + finally: + conn.close() + + +def _mark_meta_sync( + exchange_key: str, + symbol: str, + *, + kline_ms: int | None = None, + trade_ms: int | None = None, + seed_complete: bool | None = None, + db_path: Path | None = None, +) -> None: + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + conn = _connect(db_path) + try: + sets = [] + params: list[Any] = [] + if kline_ms is not None: + sets.append("last_kline_sync_ms=?") + params.append(int(kline_ms)) + if trade_ms is not None: + sets.append("last_trade_sync_ms=?") + params.append(int(trade_ms)) + if seed_complete is not None: + sets.append("seed_complete=?") + params.append(1 if seed_complete else 0) + if not sets: + return + params.extend([ex_k, sym]) + conn.execute( + f"UPDATE archive_meta SET {', '.join(sets)} WHERE exchange_key=? AND symbol=?", + params, + ) + finally: + conn.close() + + +def fetch_remote_5m_range( + remote_fetch: Callable[..., dict[str, Any]], + symbol: str, + start_ms: int, + end_ms: int, +) -> list[dict[str, Any]]: + """经实例 /api/hub/ohlcv 分页拉取 5m。""" + period = TIMEFRAME_MS["5m"] + since = max(0, int(start_ms)) + end = int(end_ms) + merged: dict[int, dict[str, Any]] = {} + guard = 0 + while since < end and guard < 120: + guard += 1 + remote = remote_fetch(symbol=symbol, timeframe="5m", since_ms=since, limit=500) + if not remote.get("ok"): + break + batch = remote.get("bars") or [] + if not batch: + break + for b in batch: + try: + ts = int(b["open_time_ms"]) + merged[ts] = b + except (KeyError, TypeError, ValueError): + continue + last_ts = max(int(b["open_time_ms"]) for b in batch) + next_since = last_ts + period + if next_since <= since: + break + since = next_since + if last_ts >= end: + break + return [merged[k] for k in sorted(merged.keys()) if start_ms <= k <= end] + + +def seed_symbol_archive( + exchange_key: str, + symbol: str, + first_opened_ms: int, + remote_fetch: Callable[..., dict[str, Any]], + *, + db_path: Path | None = None, +) -> dict[str, Any]: + """建档:最早开仓向前 30 天 5m 种子。""" + init_db(db_path) + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + anchor = int(first_opened_ms) + start_ms = max(0, anchor - ARCHIVE_SEED_LOOKBACK_DAYS * 86400000) + end_ms = _now_ms() + _ensure_meta(ex_k, sym, anchor, db_path=db_path) + bars = fetch_remote_5m_range(remote_fetch, sym, start_ms, end_ms) + n = upsert_bars_5m(ex_k, sym, bars, db_path=db_path) + now = _now_ms() + _mark_meta_sync(ex_k, sym, kline_ms=now, seed_complete=True, db_path=db_path) + return {"ok": True, "seed_bars": n, "start_ms": start_ms, "end_ms": end_ms} + + +def sync_symbol_klines_incremental( + exchange_key: str, + symbol: str, + remote_fetch: Callable[..., dict[str, Any]], + *, + db_path: Path | None = None, +) -> dict[str, Any]: + """增量补 5m 至当前。""" + init_db(db_path) + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + conn = _connect(db_path) + try: + row = conn.execute( + "SELECT MAX(open_time_ms) AS mx FROM archive_bars_5m WHERE exchange_key=? AND symbol=?", + (ex_k, sym), + ).fetchone() + last_bar = int(row["mx"]) if row and row["mx"] else None + finally: + conn.close() + + period = TIMEFRAME_MS["5m"] + start_ms = max(0, (last_bar + period) if last_bar else 0) + end_ms = _now_ms() + if start_ms >= end_ms - period: + return {"ok": True, "appended": 0, "skipped": True} + bars = fetch_remote_5m_range(remote_fetch, sym, start_ms, end_ms) + n = upsert_bars_5m(ex_k, sym, bars, db_path=db_path) + now = _now_ms() + _mark_meta_sync(ex_k, sym, kline_ms=now, db_path=db_path) + return {"ok": True, "appended": n, "start_ms": start_ms, "end_ms": end_ms} + + +def sync_exchange_symbol_archives( + exchange_key: str, + trades: list[dict[str, Any]], + remote_fetch: Callable[..., dict[str, Any]], + *, + db_path: Path | None = None, +) -> dict[str, Any]: + """同步单所:交易缓存 + 各币种 K 线种子/增量。""" + ex_k = (exchange_key or "").strip().lower() + cache_stats = upsert_trades_cache(ex_k, trades, db_path=db_path, prune_missing=True) + + by_sym: dict[str, int] = {} + for t in trades or []: + sym = (t.get("symbol") or "").strip().upper() + if not sym: + continue + oms = t.get("opened_at_ms") or _parse_dt_ms(t.get("opened_at")) + if oms: + cur = by_sym.get(sym) + if cur is None or int(oms) < cur: + by_sym[sym] = int(oms) + + seeded = 0 + appended = 0 + for sym, first_ms in by_sym.items(): + _ensure_meta(ex_k, sym, first_ms, db_path=db_path) + conn = _connect(db_path) + try: + meta = conn.execute( + "SELECT seed_complete FROM archive_meta WHERE exchange_key=? AND symbol=?", + (ex_k, sym), + ).fetchone() + finally: + conn.close() + if not meta or not int(meta["seed_complete"] or 0): + r = seed_symbol_archive(ex_k, sym, first_ms, remote_fetch, db_path=db_path) + seeded += int(r.get("seed_bars") or 0) + else: + r = sync_symbol_klines_incremental(ex_k, sym, remote_fetch, db_path=db_path) + appended += int(r.get("appended") or 0) + + return { + "ok": True, + "exchange_key": ex_k, + "symbols": len(by_sym), + "trades_upserted": int(cache_stats.get("upserted") or 0), + "trades_removed": int(cache_stats.get("removed") or 0), + "seed_bars": seeded, + "appended_bars": appended, + "trades": len(trades or []), + } + + +def ms_to_trading_day( + ms: int | None, + *, + reset_hour: int = TRADING_DAY_RESET_HOUR, + tz: ZoneInfo = CHART_DISPLAY_TZ, +) -> str | None: + if ms is None: + return None + try: + dt = datetime.fromtimestamp(int(ms) / 1000.0, tz=timezone.utc).astimezone(tz) + except (TypeError, ValueError, OSError): + return None + if dt.hour < reset_hour: + dt = dt - timedelta(days=1) + return dt.strftime("%Y-%m-%d") + + +def today_trading_day(*, reset_hour: int = TRADING_DAY_RESET_HOUR) -> str: + return ms_to_trading_day(_now_ms(), reset_hour=reset_hour) or datetime.now( + CHART_DISPLAY_TZ + ).strftime("%Y-%m-%d") + + +def trading_day_bounds_ms( + trading_day: str, + *, + reset_hour: int = TRADING_DAY_RESET_HOUR, + tz: ZoneInfo = CHART_DISPLAY_TZ, +) -> tuple[int, int]: + day = datetime.strptime((trading_day or "").strip()[:10], "%Y-%m-%d") + start = day.replace(hour=reset_hour, minute=0, second=0, microsecond=0, tzinfo=tz) + end = start + timedelta(days=1) + return int(start.timestamp() * 1000), int(end.timestamp() * 1000) + + +def resolve_period_bounds( + *, + period: str = "", + trading_day: str = "", + date_from: str = "", + date_to: str = "", + reset_hour: int = TRADING_DAY_RESET_HOUR, +) -> tuple[int, int, str, str, str]: + """返回 (start_ms, end_ms, date_from, date_to, period_label)。""" + td = today_trading_day(reset_hour=reset_hour) + p = (period or "today").strip().lower() + if p in ("day", "today", ""): + d = (trading_day or "").strip()[:10] or td + start_ms, end_ms = trading_day_bounds_ms(d, reset_hour=reset_hour) + return start_ms, end_ms, d, d, f"本日 {d}" + if p == "week": + day_dt = datetime.strptime(td, "%Y-%m-%d") + monday = day_dt - timedelta(days=day_dt.weekday()) + df = monday.strftime("%Y-%m-%d") + start_ms, _ = trading_day_bounds_ms(df, reset_hour=reset_hour) + _, end_ms = trading_day_bounds_ms(td, reset_hour=reset_hour) + return start_ms, end_ms, df, td, f"本周 {df}~{td}" + if p == "month": + day_dt = datetime.strptime(td, "%Y-%m-%d") + first = day_dt.replace(day=1) + df = first.strftime("%Y-%m-%d") + start_ms, _ = trading_day_bounds_ms(df, reset_hour=reset_hour) + _, end_ms = trading_day_bounds_ms(td, reset_hour=reset_hour) + return start_ms, end_ms, df, td, f"本月 {df}~{td}" + if p == "range": + df = (date_from or "").strip()[:10] or td + dt = (date_to or "").strip()[:10] or df + if df > dt: + df, dt = dt, df + start_ms, _ = trading_day_bounds_ms(df, reset_hour=reset_hour) + _, end_ms = trading_day_bounds_ms(dt, reset_hour=reset_hour) + label = f"区间 {df}~{dt}" if df != dt else f"区间 {df}" + return start_ms, end_ms, df, dt, label + d = (trading_day or "").strip()[:10] or td + start_ms, end_ms = trading_day_bounds_ms(d, reset_hour=reset_hour) + return start_ms, end_ms, d, d, f"本日 {d}" + + +def _pnl_side(pnl: float) -> str: + if pnl > 0.0001: + return "win" + if pnl < -0.0001: + return "loss" + return "flat" + + +def _empty_pnl_bucket() -> dict[str, Any]: + return { + "open_count": 0, + "sick_count": 0, + "pnl_total": 0.0, + "pnl_ex_sick": 0.0, + "turnover_total": 0.0, + "commission_total": 0.0, + "win_count": 0, + "loss_count": 0, + "avg_win": None, + "avg_loss": None, + "max_win": None, + "max_loss": None, + } + + +def _finalize_pnl_bucket(bucket: dict[str, Any]) -> None: + wins = bucket.pop("_wins", []) + losses = bucket.pop("_losses", []) + open_count = int(bucket.get("open_count") or 0) + win_count = len(wins) + bucket["win_count"] = win_count + bucket["loss_count"] = len(losses) + bucket["avg_win"] = round(sum(wins) / len(wins), 4) if wins else None + avg_loss = round(sum(losses) / len(losses), 4) if losses else None + bucket["avg_loss"] = avg_loss + bucket["max_win"] = round(max(wins), 4) if wins else None + bucket["max_loss"] = round(min(losses), 4) if losses else None + bucket["pnl_total"] = round(float(bucket.get("pnl_total") or 0), 4) + bucket["pnl_ex_sick"] = round(float(bucket.get("pnl_ex_sick") or 0), 4) + bucket["turnover_total"] = round(float(bucket.get("turnover_total") or 0), 4) + bucket["commission_total"] = round(float(bucket.get("commission_total") or 0), 4) + bucket["win_rate"] = round(win_count / open_count * 100, 1) if open_count else None + avg_win = bucket["avg_win"] + if avg_win is not None and avg_loss is not None and avg_loss != 0: + bucket["profit_loss_ratio"] = round(avg_win / abs(avg_loss), 2) + else: + bucket["profit_loss_ratio"] = None + + +def _accumulate_trade_stat( + bucket: dict[str, Any], + *, + pnl: float, + is_sick: bool, + turnover: float = 0.0, + commission: float = 0.0, +) -> None: + bucket["open_count"] += 1 + bucket["pnl_total"] += pnl + bucket["turnover_total"] += turnover + bucket["commission_total"] += commission + if is_sick: + bucket["sick_count"] += 1 + else: + bucket["pnl_ex_sick"] += pnl + side = _pnl_side(pnl) + if side == "win": + bucket.setdefault("_wins", []).append(pnl) + elif side == "loss": + bucket.setdefault("_losses", []).append(pnl) + + +def _compute_period_stats(trade_rows: list[dict[str, Any]]) -> dict[str, Any]: + total_bucket = _empty_pnl_bucket() + by_ex: dict[str, dict[str, Any]] = {} + for td_row in trade_rows: + ex = str(td_row.get("exchange_key") or "?") + pnl = float(td_row.get("pnl_amount") or 0) + tag = str(td_row.get("behavior_tag") or "") + is_sick = tag == "sick" + turnover = float(td_row.get("exchange_turnover_usdt") or 0) + commission = float(td_row.get("exchange_commission_usdt") or 0) + _accumulate_trade_stat( + total_bucket, pnl=pnl, is_sick=is_sick, turnover=turnover, commission=commission + ) + if ex not in by_ex: + by_ex[ex] = _empty_pnl_bucket() + _accumulate_trade_stat( + by_ex[ex], pnl=pnl, is_sick=is_sick, turnover=turnover, commission=commission + ) + _finalize_pnl_bucket(total_bucket) + for ex in by_ex: + _finalize_pnl_bucket(by_ex[ex]) + total = int(total_bucket["open_count"] or 0) + sick = int(total_bucket["sick_count"] or 0) + sick_pct = round(sick / total * 100, 1) if total else 0.0 + return { + "open_count": total, + "sick_count": sick, + "sick_pct": sick_pct, + "pnl_total": total_bucket["pnl_total"], + "pnl_ex_sick": total_bucket["pnl_ex_sick"], + "win_count": total_bucket["win_count"], + "loss_count": total_bucket["loss_count"], + "avg_win": total_bucket["avg_win"], + "avg_loss": total_bucket["avg_loss"], + "max_win": total_bucket["max_win"], + "max_loss": total_bucket["max_loss"], + "win_rate": total_bucket["win_rate"], + "profit_loss_ratio": total_bucket["profit_loss_ratio"], + "turnover_total": total_bucket["turnover_total"], + "commission_total": total_bucket["commission_total"], + "by_exchange": by_ex, + } + + +def list_review_quotes(*, db_path: Path | None = None) -> list[dict[str, Any]]: + init_db(db_path) + conn = _connect(db_path) + try: + rows = conn.execute( + """ + SELECT id, quote_date, content, created_at, updated_at + FROM archive_review_quotes + ORDER BY quote_date DESC + LIMIT ? + """, + (ARCHIVE_QUOTES_MAX,), + ).fetchall() + return [dict(r) for r in rows] + finally: + conn.close() + + +def create_review_quote( + quote_date: str, + content: str, + *, + db_path: Path | None = None, +) -> dict[str, Any]: + init_db(db_path) + qd = (quote_date or "").strip()[:10] + if not qd: + raise ValueError("缺少 quote_date") + text = (content or "").strip() + if not text: + raise ValueError("语录内容不能为空") + if len(text) > ARCHIVE_QUOTE_MAX_LEN: + raise ValueError(f"语录最长 {ARCHIVE_QUOTE_MAX_LEN} 字") + conn = _connect(db_path) + try: + cnt = conn.execute("SELECT COUNT(*) AS c FROM archive_review_quotes").fetchone() + if int(cnt["c"] or 0) >= ARCHIVE_QUOTES_MAX: + raise ValueError(f"复盘语录最多保存 {ARCHIVE_QUOTES_MAX} 条") + now = _now_ms() + try: + cur = conn.execute( + """ + INSERT INTO archive_review_quotes (quote_date, content, created_at, updated_at) + VALUES (?,?,?,?) + """, + (qd, text, now, now), + ) + except sqlite3.IntegrityError as e: + raise ValueError("该日期已有语录,请展开编辑") from e + rid = int(cur.lastrowid) + row = conn.execute( + "SELECT id, quote_date, content, created_at, updated_at FROM archive_review_quotes WHERE id=?", + (rid,), + ).fetchone() + return dict(row) + finally: + conn.close() + + +def update_review_quote( + quote_id: int, + *, + quote_date: str | None = None, + content: str | None = None, + db_path: Path | None = None, +) -> dict[str, Any] | None: + init_db(db_path) + conn = _connect(db_path) + try: + row = conn.execute( + "SELECT id, quote_date, content FROM archive_review_quotes WHERE id=?", + (int(quote_id),), + ).fetchone() + if not row: + return None + qd = (quote_date or row["quote_date"] or "").strip()[:10] + text = (content if content is not None else row["content"] or "").strip() + if not qd or not text: + raise ValueError("日期与内容均不能为空") + if len(text) > ARCHIVE_QUOTE_MAX_LEN: + raise ValueError(f"语录最长 {ARCHIVE_QUOTE_MAX_LEN} 字") + now = _now_ms() + conn.execute( + """ + UPDATE archive_review_quotes + SET quote_date=?, content=?, updated_at=? + WHERE id=? + """, + (qd, text, now, int(quote_id)), + ) + out = conn.execute( + "SELECT id, quote_date, content, created_at, updated_at FROM archive_review_quotes WHERE id=?", + (int(quote_id),), + ).fetchone() + return dict(out) if out else None + finally: + conn.close() + + +def delete_review_quote(quote_id: int, *, db_path: Path | None = None) -> bool: + init_db(db_path) + conn = _connect(db_path) + try: + cur = conn.execute( + "DELETE FROM archive_review_quotes WHERE id=?", + (int(quote_id),), + ) + return int(cur.rowcount or 0) > 0 + finally: + conn.close() + + +def list_daily_trades( + trading_day: str = "", + *, + period: str = "", + date_from: str = "", + date_to: str = "", + exchange_key: str = "", + filter_profit: bool = False, + filter_loss: bool = False, + filter_sick: bool = False, + search: str = "", + db_path: Path | None = None, +) -> dict[str, Any]: + """按日期区间列出平仓记录(本日/本周/本月/自选,以平仓时间计),含犯病与盈亏统计。""" + init_db(db_path) + p = (period or "today").strip().lower() or "today" + start_ms, end_ms, df, dt, period_label = resolve_period_bounds( + period=p, + trading_day=trading_day, + date_from=date_from, + date_to=date_to, + ) + ex_filter = (exchange_key or "").strip().lower() + conn = _connect(db_path) + try: + params: list[Any] = [start_ms, end_ms] + where = "closed_at_ms IS NOT NULL AND closed_at_ms >= ? AND closed_at_ms < ?" + if ex_filter: + where += " AND exchange_key=?" + params.append(ex_filter) + rows = conn.execute( + f""" + SELECT * FROM archive_trade_cache + WHERE {where} + ORDER BY closed_at_ms DESC, trade_id DESC + """, + params, + ).fetchall() + overlays_by_ex: dict[str, dict[int, dict]] = {} + trades: list[dict[str, Any]] = [] + q = (search or "").strip().lower() + for r in rows: + ex_k = r["exchange_key"] + if ex_k not in overlays_by_ex: + overlays_by_ex[ex_k] = load_overlays(ex_k, db_path=db_path) + td_row = _trade_row_to_dict(r, overlays_by_ex[ex_k].get(int(r["trade_id"]))) + pnl = float(td_row.get("pnl_amount") or 0) + tag = td_row.get("behavior_tag") or "" + if filter_profit and pnl <= 0.0001: + continue + if filter_loss and pnl >= -0.0001: + continue + if filter_sick and tag != "sick": + continue + if q: + blob = " ".join( + str(td_row.get(k) or "") + for k in ( + "symbol", + "exchange_key", + "direction", + "result", + "note", + "monitor_type", + "entry_reason", + ) + ).lower() + if q not in blob: + continue + trades.append(td_row) + return { + "period": p, + "period_label": period_label, + "trading_day": dt, + "date_from": df, + "date_to": dt, + "trades": trades, + "stats": _compute_period_stats(trades), + } + finally: + conn.close() + + +def list_archive_calendar( + year: int, + month: int, + *, + exchange_key: str = "", + db_path: Path | None = None, + reset_hour: int = TRADING_DAY_RESET_HOUR, +) -> dict[str, Any]: + """按月返回每个交易日的盈亏、笔数、犯病标记(08:00 切日)。""" + init_db(db_path) + y = int(year) + m = int(month) + if m < 1 or m > 12: + raise ValueError("month 无效") + first = f"{y:04d}-{m:02d}-01" + if m == 12: + next_first = datetime(y + 1, 1, 1) + else: + next_first = datetime(y, m + 1, 1) + last = (next_first - timedelta(days=1)).strftime("%Y-%m-%d") + start_ms, _ = trading_day_bounds_ms(first, reset_hour=reset_hour) + _, end_ms = trading_day_bounds_ms(last, reset_hour=reset_hour) + ex_filter = (exchange_key or "").strip().lower() + conn = _connect(db_path) + try: + params: list[Any] = [start_ms, end_ms] + where = "closed_at_ms IS NOT NULL AND closed_at_ms >= ? AND closed_at_ms < ?" + if ex_filter: + where += " AND exchange_key=?" + params.append(ex_filter) + rows = conn.execute( + f"SELECT * FROM archive_trade_cache WHERE {where}", + params, + ).fetchall() + overlays_by_ex: dict[str, dict[int, dict]] = {} + days: dict[str, dict[str, Any]] = {} + for r in rows: + ex_k = r["exchange_key"] + if ex_k not in overlays_by_ex: + overlays_by_ex[ex_k] = load_overlays(ex_k, db_path=db_path) + td_row = _trade_row_to_dict(r, overlays_by_ex[ex_k].get(int(r["trade_id"]))) + closed_ms = td_row.get("closed_at_ms") or _parse_dt_ms(td_row.get("closed_at")) + if not closed_ms: + continue + day = ms_to_trading_day(int(closed_ms), reset_hour=reset_hour) + if not day: + continue + if day < first or day > last: + continue + bucket = days.setdefault( + day, + { + "trading_day": day, + "open_count": 0, + "sick_count": 0, + "pnl_total": 0.0, + "turnover_total": 0.0, + "commission_total": 0.0, + "has_sick": False, + }, + ) + pnl = float(td_row.get("pnl_amount") or 0) + tag = str(td_row.get("behavior_tag") or "") + is_sick = tag == "sick" + bucket["open_count"] += 1 + bucket["pnl_total"] += pnl + bucket["turnover_total"] += float(td_row.get("exchange_turnover_usdt") or 0) + bucket["commission_total"] += float(td_row.get("exchange_commission_usdt") or 0) + if is_sick: + bucket["sick_count"] += 1 + bucket["has_sick"] = True + for d in days.values(): + d["pnl_total"] = round(float(d["pnl_total"]), 4) + d["turnover_total"] = round(float(d["turnover_total"]), 4) + d["commission_total"] = round(float(d["commission_total"]), 4) + month_pnl = sum(float(d["pnl_total"]) for d in days.values()) + month_count = sum(int(d["open_count"]) for d in days.values()) + return { + "year": y, + "month": m, + "date_from": first, + "date_to": last, + "days": days, + "month_pnl_total": round(month_pnl, 4), + "month_open_count": month_count, + } + finally: + conn.close() diff --git a/hub_trades_lib.py b/lib/hub/hub_trades_lib.py similarity index 96% rename from hub_trades_lib.py rename to lib/hub/hub_trades_lib.py index c25f0b9..292cda3 100644 --- a/hub_trades_lib.py +++ b/lib/hub/hub_trades_lib.py @@ -1,638 +1,638 @@ -"""各实例当日平仓记录查询(供 hub_bridge /api/hub/trades/today 与中控 AI 聚合)。""" -from __future__ import annotations - -from datetime import datetime, timedelta -from typing import Any, Callable, Optional - -from strategy_trade_labels import ( - MONITOR_TYPE_ROLL, - MONITOR_TYPE_TREND_PULLBACK, - entry_reason_for_monitor_type, -) -from time_close_lib import TIME_CLOSE_RESULT - -TRADE_COMPLETED_RESULTS = ( - "止盈", - "止损", - "保本止盈", - "移动止盈", - "手动平仓", - "强制清仓", - "外部平仓", - TIME_CLOSE_RESULT, -) - - -def trading_day_from_dt(dt: datetime, reset_hour: int = 8) -> str: - """与实例 get_trading_day 一致:小时 < reset_hour 归属上一日历日。""" - if dt.hour < reset_hour: - dt = dt - timedelta(days=1) - return dt.strftime("%Y-%m-%d") - - -def current_trading_day(*, now: datetime | None = None, reset_hour: int = 8) -> str: - return trading_day_from_dt(now or datetime.now(), reset_hour) - - -def parse_dt_for_trading_day(raw: Any) -> datetime | None: - if raw is None: - return None - s = str(raw).strip().replace("Z", "").replace("T", " ") - if not s: - return None - for fmt, ln in (("%Y-%m-%d %H:%M:%S", 19), ("%Y-%m-%d %H:%M", 16), ("%Y-%m-%d", 10)): - try: - return datetime.strptime(s[:ln], fmt) - except ValueError: - continue - return None - - -def trading_day_window_bounds(trading_day: str, reset_hour: int = 8) -> tuple[str, str]: - """交易日 [reset_hour, 次日 reset_hour) 对应的北京时间字符串区间(闭区间)。""" - day = datetime.strptime((trading_day or "").strip()[:10], "%Y-%m-%d") - start = day.replace(hour=reset_hour, minute=0, second=0, microsecond=0) - end = start + timedelta(days=1) - timedelta(seconds=1) - return start.strftime("%Y-%m-%d %H:%M:%S"), end.strftime("%Y-%m-%d %H:%M:%S") - - -def _row_dict(row, row_to_dict: Optional[Callable] = None) -> dict: - if row is None: - return {} - if row_to_dict: - try: - return dict(row_to_dict(row)) - except Exception: - pass - try: - keys = row.keys() if hasattr(row, "keys") else () - if keys: - return {k: row[k] for k in keys} - except Exception: - pass - try: - return dict(row) - except Exception: - return {} - - -def _effective_field(d: dict, reviewed_key: str, base_key: str, default: Any = None) -> Any: - rv = d.get(reviewed_key) - if rv is not None and str(rv).strip() != "": - return rv - bv = d.get(base_key) - if bv is not None and str(bv).strip() != "": - return bv - return default - - -def format_hold_minutes(minutes: Any) -> str: - try: - total = int(minutes or 0) - except (TypeError, ValueError): - return "0分钟" - if total <= 0: - return "0分钟" - hours = total // 60 - mins = total % 60 - if hours: - return f"{hours}小时{mins}分钟" - return f"{mins}分钟" - - -def _normalize_monitor_type_label(raw: Any) -> str: - mt = str(raw or "").strip() - if mt in ("trend_pullback", "trend"): - return MONITOR_TYPE_TREND_PULLBACK - if mt in ("roll",): - return MONITOR_TYPE_ROLL - return mt - - -def effective_entry_type(d: dict) -> str: - """复盘开仓类型优先,与实例交易记录 effective_entry_reason 一致。""" - er = _effective_field(d, "reviewed_entry_reason", "entry_reason") - if er is not None and str(er).strip(): - return str(er).strip() - mt = _normalize_monitor_type_label(d.get("monitor_type")) - er2 = entry_reason_for_monitor_type(mt) - if er2: - return er2 - kst = str(d.get("key_signal_type") or "").strip() - if kst: - return kst - legacy = str(d.get("entry_type") or "").strip() - if legacy and legacy not in ("trend_pullback", "roll", "trend"): - return _normalize_monitor_type_label(legacy) or legacy - return mt - - -def display_entry_type_label(d: dict) -> str: - """档案/列表展示用开仓类型(不回落为「下单监控」若已有复盘或建档类型)。""" - label = effective_entry_type(d).strip() - if not label: - return "—" - return _normalize_monitor_type_label(label) or label - - -def effective_hold_minutes( - d: dict, - *, - opened_ms: int | None = None, - closed_ms: int | None = None, -) -> int: - hm = _effective_field(d, "reviewed_hold_minutes", "hold_minutes") - if hm is not None and str(hm).strip() != "": - try: - return max(0, int(hm)) - except (TypeError, ValueError): - pass - hs = _effective_field(d, "reviewed_hold_seconds", "hold_seconds") - if hs is not None and str(hs).strip() != "": - try: - return max(0, int(int(hs) // 60)) - except (TypeError, ValueError): - pass - oms = opened_ms if opened_ms is not None else d.get("opened_at_ms") - cms = closed_ms if closed_ms is not None else d.get("closed_at_ms") - try: - oms_i = int(oms) if oms not in (None, "") else None - cms_i = int(cms) if cms not in (None, "") else None - except (TypeError, ValueError): - oms_i = cms_i = None - if oms_i and cms_i and cms_i > oms_i: - return max(0, int((cms_i - oms_i) // 60_000)) - return 0 - - -def _effective_pnl(d: dict) -> float: - reviewed = d.get("reviewed_pnl_amount") - if reviewed is not None and str(reviewed).strip() != "": - try: - return float(reviewed) - except (TypeError, ValueError): - pass - ex = d.get("exchange_realized_pnl") - if ex is not None and str(ex).strip() != "": - try: - return float(ex) - except (TypeError, ValueError): - pass - try: - return float(d.get("pnl_amount") or 0) - except (TypeError, ValueError): - return 0.0 - - -def _trade_close_dt(d: dict) -> datetime | None: - raw = _effective_field(d, "reviewed_closed_at", "closed_at") - if raw is None or str(raw).strip() == "": - raw = d.get("created_at") or d.get("opened_at") - return parse_dt_for_trading_day(raw) - - -def _normalize_trade_row( - d: dict, - *, - trading_day: str, - reset_hour: int, -) -> dict[str, Any] | None: - effective_result = str(_effective_field(d, "reviewed_result", "result") or "").strip() - if effective_result not in TRADE_COMPLETED_RESULTS: - return None - close_dt = _trade_close_dt(d) - if not close_dt: - return None - if trading_day_from_dt(close_dt, reset_hour) != trading_day: - return None - pnl = _effective_pnl(d) - closed_at = _effective_field(d, "reviewed_closed_at", "closed_at") - opened_at = _effective_field(d, "reviewed_opened_at", "opened_at") - return { - "symbol": d.get("symbol"), - "direction": d.get("direction"), - "result": effective_result, - "pnl_amount": round(pnl, 4), - "closed_at": closed_at, - "opened_at": opened_at, - "monitor_type": d.get("monitor_type"), - "actual_rr": d.get("actual_rr"), - "planned_rr": d.get("planned_rr"), - "trade_style": d.get("trade_style"), - "entry_reason": d.get("entry_reason"), - "reviewed": bool(d.get("reviewed_at") or d.get("reviewed_result")), - } - - -def fetch_trades_for_trading_day( - conn, - trading_day: str, - *, - row_to_dict_fn: Optional[Callable] = None, - reset_hour: int = 8, - limit: int = 200, -) -> list[dict[str, Any]]: - """返回指定交易日的已平仓记录(与 /records 交易记录一致,复盘字段优先)。""" - day = (trading_day or "").strip()[:10] - if not day: - return [] - lim = max(1, min(int(limit or 200), 500)) - start_bj, end_bj = trading_day_window_bounds(day, reset_hour) - ts_expr = "REPLACE(COALESCE(reviewed_closed_at, closed_at, created_at, opened_at), 'T', ' ')" - rows = conn.execute( - f""" - SELECT symbol, direction, result, reviewed_result, pnl_amount, reviewed_pnl_amount, - exchange_realized_pnl, closed_at, reviewed_closed_at, opened_at, reviewed_opened_at, - created_at, monitor_type, actual_rr, planned_rr, trade_style, entry_reason, - reviewed_at - FROM trade_records - WHERE {ts_expr} >= ? AND {ts_expr} <= ? - ORDER BY {ts_expr} ASC - LIMIT ? - """, - (start_bj, end_bj, lim * 3), - ).fetchall() - out: list[dict[str, Any]] = [] - for row in rows: - d = _row_dict(row, row_to_dict_fn) - norm = _normalize_trade_row(d, trading_day=day, reset_hour=reset_hour) - if norm: - out.append(norm) - if len(out) >= lim: - break - return out - - -def _normalize_archive_trade_row( - d: dict, - *, - exchange_key: str = "", - reset_hour: int = 8, -) -> dict[str, Any] | None: - """全历史档案用:已平仓记录(不按交易日截断)。""" - effective_result = str(_effective_field(d, "reviewed_result", "result") or "").strip() - if effective_result not in TRADE_COMPLETED_RESULTS: - return None - close_dt = _trade_close_dt(d) - if not close_dt: - return None - pnl = _effective_pnl(d) - closed_at = _effective_field(d, "reviewed_closed_at", "closed_at") - opened_at = _effective_field(d, "reviewed_opened_at", "opened_at") - opened_ms = d.get("opened_at_ms") - closed_ms = d.get("closed_at_ms") - if opened_ms in (None, ""): - odt = parse_dt_for_trading_day(opened_at) - opened_ms = int(odt.timestamp() * 1000) if odt else None - if closed_ms in (None, ""): - cdt = close_dt - closed_ms = int(cdt.timestamp() * 1000) if cdt else None - try: - trade_id = int(d.get("id")) - except (TypeError, ValueError): - return None - opened_ms_i = int(opened_ms) if opened_ms else None - closed_ms_i = int(closed_ms) if closed_ms else None - hold_m = effective_hold_minutes(d, opened_ms=opened_ms_i, closed_ms=closed_ms_i) - entry_type = display_entry_type_label(d) - reviewed = bool( - d.get("reviewed_at") - or d.get("reviewed_result") - or d.get("reviewed_opened_at") - or d.get("reviewed_closed_at") - or d.get("reviewed_entry_reason") - or d.get("reviewed_hold_minutes") - ) - return { - "id": trade_id, - "exchange_key": (exchange_key or "").strip().lower(), - "symbol": (d.get("symbol") or "").strip().upper(), - "direction": d.get("direction"), - "result": effective_result, - "pnl_amount": round(pnl, 4), - "closed_at": closed_at, - "opened_at": opened_at, - "opened_at_ms": opened_ms_i, - "closed_at_ms": closed_ms_i, - "monitor_type": _normalize_monitor_type_label(d.get("monitor_type")), - "entry_type": entry_type, - "entry_reason": entry_type, - "hold_minutes": hold_m, - "hold_minutes_text": format_hold_minutes(hold_m), - "actual_rr": d.get("actual_rr"), - "planned_rr": d.get("planned_rr"), - "trade_style": d.get("trade_style"), - "trigger_price": d.get("trigger_price"), - "stop_loss": _effective_field(d, "reviewed_stop_loss", "stop_loss"), - "take_profit": _effective_field(d, "reviewed_take_profit", "take_profit"), - "reviewed": reviewed, - "trading_day": trading_day_from_dt(close_dt, reset_hour), - "exchange_turnover_usdt": d.get("exchange_turnover_usdt"), - "exchange_commission_usdt": d.get("exchange_commission_usdt"), - } - - -_SNAPSHOT_STATUS_TO_RESULT = { - "stopped_sl": "止损", - "stopped_tp": "止盈", - "stopped_manual": "手动平仓", - "stopped_external": "外部平仓", -} - - -def _table_columns(conn, table: str) -> set[str]: - try: - rows = conn.execute(f"PRAGMA table_info({table})").fetchall() - except Exception: - return set() - out: set[str] = set() - for r in rows: - try: - out.add(str(r[1])) - except (IndexError, TypeError): - try: - out.add(str(r["name"])) - except Exception: - continue - return out - - -def _archive_ts_expr(cols: set[str]) -> str: - parts = [c for c in ("reviewed_closed_at", "closed_at", "created_at", "opened_at") if c in cols] - if not parts: - return "''" - return f"REPLACE(COALESCE({', '.join(parts)}), 'T', ' ')" - - -def _archive_trade_select_sql(cols: set[str]) -> str: - wanted = [ - "id", - "symbol", - "direction", - "result", - "reviewed_result", - "pnl_amount", - "reviewed_pnl_amount", - "exchange_realized_pnl", - "closed_at", - "reviewed_closed_at", - "opened_at", - "reviewed_opened_at", - "opened_at_ms", - "closed_at_ms", - "created_at", - "monitor_type", - "key_signal_type", - "actual_rr", - "planned_rr", - "trade_style", - "entry_reason", - "reviewed_entry_reason", - "hold_minutes", - "reviewed_hold_minutes", - "hold_seconds", - "reviewed_hold_seconds", - "trigger_price", - "stop_loss", - "take_profit", - "reviewed_stop_loss", - "reviewed_take_profit", - "reviewed_at", - "trend_plan_id", - "exchange_turnover_usdt", - "exchange_commission_usdt", - ] - select_cols = [c for c in wanted if c in cols] - if "id" not in select_cols: - select_cols = ["id"] + select_cols - return ", ".join(select_cols) - - -def _existing_trend_plan_ids(conn) -> set[int]: - cols = _table_columns(conn, "trade_records") - if "trend_plan_id" not in cols: - return set() - rows = conn.execute( - "SELECT DISTINCT trend_plan_id FROM trade_records WHERE trend_plan_id IS NOT NULL" - ).fetchall() - out: set[int] = set() - for row in rows: - d = _row_dict(row) - try: - out.add(int(d.get("trend_plan_id"))) - except (TypeError, ValueError): - continue - return out - - -def _normalize_snapshot_archive_row( - snap: dict, - *, - exchange_key: str = "", - reset_hour: int = 8, -) -> dict[str, Any] | None: - result = str(snap.get("result_label") or "").strip() - if not result: - result = _SNAPSHOT_STATUS_TO_RESULT.get( - str(snap.get("status_at_close") or "").strip(), "" - ) - if result not in TRADE_COMPLETED_RESULTS: - return None - closed_at = snap.get("closed_at") - close_dt = parse_dt_for_trading_day(closed_at) - if not close_dt: - return None - opened_at = snap.get("opened_at") - opened_ms = _parse_ms_from_row(snap.get("opened_at")) - closed_ms = _parse_ms_from_row(closed_at) - try: - snap_id = int(snap.get("id")) - except (TypeError, ValueError): - return None - try: - pnl = float(snap.get("pnl_amount") or 0) - except (TypeError, ValueError): - pnl = 0.0 - st = str(snap.get("strategy_type") or "").strip() - monitor_type = _normalize_monitor_type_label( - "trend_pullback" if st == "trend_pullback" else ("roll" if st == "roll" else st) - ) - hold_m = effective_hold_minutes( - {}, - opened_ms=opened_ms, - closed_ms=closed_ms, - ) - entry_type = entry_reason_for_monitor_type(monitor_type) or monitor_type - return { - "id": -snap_id, - "exchange_key": (exchange_key or "").strip().lower(), - "symbol": (snap.get("symbol") or "").strip().upper(), - "direction": snap.get("direction"), - "result": result, - "pnl_amount": round(pnl, 4), - "closed_at": closed_at, - "opened_at": opened_at, - "opened_at_ms": opened_ms, - "closed_at_ms": closed_ms, - "monitor_type": monitor_type, - "entry_type": entry_type, - "entry_reason": entry_type, - "hold_minutes": hold_m, - "hold_minutes_text": format_hold_minutes(hold_m), - "from_snapshot": True, - "snapshot_id": snap_id, - "trend_plan_id": snap.get("source_id"), - "reviewed": False, - "trading_day": trading_day_from_dt(close_dt, reset_hour), - } - - -def _parse_ms_from_row(raw: Any) -> int | None: - if raw in (None, ""): - return None - try: - if isinstance(raw, (int, float)): - v = int(raw) - return v if v > 1_000_000_000_000 else v * 1000 - except (TypeError, ValueError): - pass - dt = parse_dt_for_trading_day(raw) - return int(dt.timestamp() * 1000) if dt else None - - -def _fetch_strategy_snapshots_for_archive( - conn, - *, - exchange_key: str = "", - days: int = 365, - reset_hour: int = 8, - limit: int = 2000, - skip_plan_ids: set[int] | None = None, -) -> list[dict[str, Any]]: - cols = _table_columns(conn, "strategy_trade_snapshots") - if not cols: - return [] - lim = max(1, min(int(limit or 2000), 5000)) - day_span = max(1, min(int(days or 365), 3650)) - cutoff = datetime.now() - timedelta(days=day_span) - cutoff_s = cutoff.strftime("%Y-%m-%d %H:%M:%S") - ts_expr = "REPLACE(COALESCE(closed_at, opened_at, created_at), 'T', ' ')" - rows = conn.execute( - f""" - SELECT * FROM strategy_trade_snapshots - WHERE {ts_expr} >= ? - ORDER BY {ts_expr} DESC - LIMIT ? - """, - (cutoff_s, lim * 2), - ).fetchall() - skip = skip_plan_ids or set() - out: list[dict[str, Any]] = [] - for row in rows: - d = _row_dict(row) - try: - source_id = int(d.get("source_id") or 0) - except (TypeError, ValueError): - source_id = 0 - if source_id > 0 and source_id in skip: - continue - norm = _normalize_snapshot_archive_row( - d, exchange_key=exchange_key, reset_hour=reset_hour - ) - if norm: - out.append(norm) - if len(out) >= lim: - break - return out - - -def fetch_trades_for_archive( - conn, - *, - exchange_key: str = "", - days: int = 365, - row_to_dict_fn: Optional[Callable] = None, - reset_hour: int = 8, - limit: int = 2000, - include_strategy_snapshots: bool = True, -) -> list[dict[str, Any]]: - """返回近 N 天已平仓记录(trade_records + 未落库的 strategy 快照)。""" - lim = max(1, min(int(limit or 2000), 5000)) - day_span = max(1, min(int(days or 365), 3650)) - cutoff = datetime.now() - timedelta(days=day_span) - cutoff_s = cutoff.strftime("%Y-%m-%d %H:%M:%S") - cols = _table_columns(conn, "trade_records") - if not cols: - records: list[dict[str, Any]] = [] - else: - ts_expr = _archive_ts_expr(cols) - sql = f""" - SELECT {_archive_trade_select_sql(cols)} - FROM trade_records - WHERE {ts_expr} >= ? - ORDER BY {ts_expr} DESC - LIMIT ? - """ - rows = conn.execute(sql, (cutoff_s, lim * 2)).fetchall() - records = [] - for row in rows: - d = _row_dict(row, row_to_dict_fn) - norm = _normalize_archive_trade_row( - d, exchange_key=exchange_key, reset_hour=reset_hour - ) - if norm: - records.append(norm) - if len(records) >= lim: - break - - if not include_strategy_snapshots: - return records - - skip_ids = _existing_trend_plan_ids(conn) - for rec in records: - try: - pid = int(rec.get("trend_plan_id") or 0) - except (TypeError, ValueError): - pid = 0 - if pid > 0: - skip_ids.add(pid) - - snaps = _fetch_strategy_snapshots_for_archive( - conn, - days=days, - exchange_key=exchange_key, - reset_hour=reset_hour, - limit=max(0, lim - len(records)), - skip_plan_ids=skip_ids, - ) - merged = records + snaps - merged.sort( - key=lambda x: int(x.get("closed_at_ms") or 0), - reverse=True, - ) - return merged[:lim] - - -def summarize_trades(trades: list[dict]) -> dict[str, Any]: - """单笔列表 → 笔数 / 盈亏 / 胜败统计。""" - total_pnl = 0.0 - win = loss = flat = 0 - for t in trades or []: - try: - pnl = float(t.get("pnl_amount") or 0) - except (TypeError, ValueError): - pnl = 0.0 - total_pnl += pnl - if pnl > 1e-9: - win += 1 - elif pnl < -1e-9: - loss += 1 - else: - flat += 1 - return { - "closed_count": len(trades or []), - "win_count": win, - "loss_count": loss, - "flat_count": flat, - "total_pnl_u": round(total_pnl, 4), - } +"""各实例当日平仓记录查询(供 hub_bridge /api/hub/trades/today 与中控 AI 聚合)。""" +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Any, Callable, Optional + +from lib.strategy.strategy_trade_labels import ( + MONITOR_TYPE_ROLL, + MONITOR_TYPE_TREND_PULLBACK, + entry_reason_for_monitor_type, +) +from lib.trade.time_close_lib import TIME_CLOSE_RESULT + +TRADE_COMPLETED_RESULTS = ( + "止盈", + "止损", + "保本止盈", + "移动止盈", + "手动平仓", + "强制清仓", + "外部平仓", + TIME_CLOSE_RESULT, +) + + +def trading_day_from_dt(dt: datetime, reset_hour: int = 8) -> str: + """与实例 get_trading_day 一致:小时 < reset_hour 归属上一日历日。""" + if dt.hour < reset_hour: + dt = dt - timedelta(days=1) + return dt.strftime("%Y-%m-%d") + + +def current_trading_day(*, now: datetime | None = None, reset_hour: int = 8) -> str: + return trading_day_from_dt(now or datetime.now(), reset_hour) + + +def parse_dt_for_trading_day(raw: Any) -> datetime | None: + if raw is None: + return None + s = str(raw).strip().replace("Z", "").replace("T", " ") + if not s: + return None + for fmt, ln in (("%Y-%m-%d %H:%M:%S", 19), ("%Y-%m-%d %H:%M", 16), ("%Y-%m-%d", 10)): + try: + return datetime.strptime(s[:ln], fmt) + except ValueError: + continue + return None + + +def trading_day_window_bounds(trading_day: str, reset_hour: int = 8) -> tuple[str, str]: + """交易日 [reset_hour, 次日 reset_hour) 对应的北京时间字符串区间(闭区间)。""" + day = datetime.strptime((trading_day or "").strip()[:10], "%Y-%m-%d") + start = day.replace(hour=reset_hour, minute=0, second=0, microsecond=0) + end = start + timedelta(days=1) - timedelta(seconds=1) + return start.strftime("%Y-%m-%d %H:%M:%S"), end.strftime("%Y-%m-%d %H:%M:%S") + + +def _row_dict(row, row_to_dict: Optional[Callable] = None) -> dict: + if row is None: + return {} + if row_to_dict: + try: + return dict(row_to_dict(row)) + except Exception: + pass + try: + keys = row.keys() if hasattr(row, "keys") else () + if keys: + return {k: row[k] for k in keys} + except Exception: + pass + try: + return dict(row) + except Exception: + return {} + + +def _effective_field(d: dict, reviewed_key: str, base_key: str, default: Any = None) -> Any: + rv = d.get(reviewed_key) + if rv is not None and str(rv).strip() != "": + return rv + bv = d.get(base_key) + if bv is not None and str(bv).strip() != "": + return bv + return default + + +def format_hold_minutes(minutes: Any) -> str: + try: + total = int(minutes or 0) + except (TypeError, ValueError): + return "0分钟" + if total <= 0: + return "0分钟" + hours = total // 60 + mins = total % 60 + if hours: + return f"{hours}小时{mins}分钟" + return f"{mins}分钟" + + +def _normalize_monitor_type_label(raw: Any) -> str: + mt = str(raw or "").strip() + if mt in ("trend_pullback", "trend"): + return MONITOR_TYPE_TREND_PULLBACK + if mt in ("roll",): + return MONITOR_TYPE_ROLL + return mt + + +def effective_entry_type(d: dict) -> str: + """复盘开仓类型优先,与实例交易记录 effective_entry_reason 一致。""" + er = _effective_field(d, "reviewed_entry_reason", "entry_reason") + if er is not None and str(er).strip(): + return str(er).strip() + mt = _normalize_monitor_type_label(d.get("monitor_type")) + er2 = entry_reason_for_monitor_type(mt) + if er2: + return er2 + kst = str(d.get("key_signal_type") or "").strip() + if kst: + return kst + legacy = str(d.get("entry_type") or "").strip() + if legacy and legacy not in ("trend_pullback", "roll", "trend"): + return _normalize_monitor_type_label(legacy) or legacy + return mt + + +def display_entry_type_label(d: dict) -> str: + """档案/列表展示用开仓类型(不回落为「下单监控」若已有复盘或建档类型)。""" + label = effective_entry_type(d).strip() + if not label: + return "—" + return _normalize_monitor_type_label(label) or label + + +def effective_hold_minutes( + d: dict, + *, + opened_ms: int | None = None, + closed_ms: int | None = None, +) -> int: + hm = _effective_field(d, "reviewed_hold_minutes", "hold_minutes") + if hm is not None and str(hm).strip() != "": + try: + return max(0, int(hm)) + except (TypeError, ValueError): + pass + hs = _effective_field(d, "reviewed_hold_seconds", "hold_seconds") + if hs is not None and str(hs).strip() != "": + try: + return max(0, int(int(hs) // 60)) + except (TypeError, ValueError): + pass + oms = opened_ms if opened_ms is not None else d.get("opened_at_ms") + cms = closed_ms if closed_ms is not None else d.get("closed_at_ms") + try: + oms_i = int(oms) if oms not in (None, "") else None + cms_i = int(cms) if cms not in (None, "") else None + except (TypeError, ValueError): + oms_i = cms_i = None + if oms_i and cms_i and cms_i > oms_i: + return max(0, int((cms_i - oms_i) // 60_000)) + return 0 + + +def _effective_pnl(d: dict) -> float: + reviewed = d.get("reviewed_pnl_amount") + if reviewed is not None and str(reviewed).strip() != "": + try: + return float(reviewed) + except (TypeError, ValueError): + pass + ex = d.get("exchange_realized_pnl") + if ex is not None and str(ex).strip() != "": + try: + return float(ex) + except (TypeError, ValueError): + pass + try: + return float(d.get("pnl_amount") or 0) + except (TypeError, ValueError): + return 0.0 + + +def _trade_close_dt(d: dict) -> datetime | None: + raw = _effective_field(d, "reviewed_closed_at", "closed_at") + if raw is None or str(raw).strip() == "": + raw = d.get("created_at") or d.get("opened_at") + return parse_dt_for_trading_day(raw) + + +def _normalize_trade_row( + d: dict, + *, + trading_day: str, + reset_hour: int, +) -> dict[str, Any] | None: + effective_result = str(_effective_field(d, "reviewed_result", "result") or "").strip() + if effective_result not in TRADE_COMPLETED_RESULTS: + return None + close_dt = _trade_close_dt(d) + if not close_dt: + return None + if trading_day_from_dt(close_dt, reset_hour) != trading_day: + return None + pnl = _effective_pnl(d) + closed_at = _effective_field(d, "reviewed_closed_at", "closed_at") + opened_at = _effective_field(d, "reviewed_opened_at", "opened_at") + return { + "symbol": d.get("symbol"), + "direction": d.get("direction"), + "result": effective_result, + "pnl_amount": round(pnl, 4), + "closed_at": closed_at, + "opened_at": opened_at, + "monitor_type": d.get("monitor_type"), + "actual_rr": d.get("actual_rr"), + "planned_rr": d.get("planned_rr"), + "trade_style": d.get("trade_style"), + "entry_reason": d.get("entry_reason"), + "reviewed": bool(d.get("reviewed_at") or d.get("reviewed_result")), + } + + +def fetch_trades_for_trading_day( + conn, + trading_day: str, + *, + row_to_dict_fn: Optional[Callable] = None, + reset_hour: int = 8, + limit: int = 200, +) -> list[dict[str, Any]]: + """返回指定交易日的已平仓记录(与 /records 交易记录一致,复盘字段优先)。""" + day = (trading_day or "").strip()[:10] + if not day: + return [] + lim = max(1, min(int(limit or 200), 500)) + start_bj, end_bj = trading_day_window_bounds(day, reset_hour) + ts_expr = "REPLACE(COALESCE(reviewed_closed_at, closed_at, created_at, opened_at), 'T', ' ')" + rows = conn.execute( + f""" + SELECT symbol, direction, result, reviewed_result, pnl_amount, reviewed_pnl_amount, + exchange_realized_pnl, closed_at, reviewed_closed_at, opened_at, reviewed_opened_at, + created_at, monitor_type, actual_rr, planned_rr, trade_style, entry_reason, + reviewed_at + FROM trade_records + WHERE {ts_expr} >= ? AND {ts_expr} <= ? + ORDER BY {ts_expr} ASC + LIMIT ? + """, + (start_bj, end_bj, lim * 3), + ).fetchall() + out: list[dict[str, Any]] = [] + for row in rows: + d = _row_dict(row, row_to_dict_fn) + norm = _normalize_trade_row(d, trading_day=day, reset_hour=reset_hour) + if norm: + out.append(norm) + if len(out) >= lim: + break + return out + + +def _normalize_archive_trade_row( + d: dict, + *, + exchange_key: str = "", + reset_hour: int = 8, +) -> dict[str, Any] | None: + """全历史档案用:已平仓记录(不按交易日截断)。""" + effective_result = str(_effective_field(d, "reviewed_result", "result") or "").strip() + if effective_result not in TRADE_COMPLETED_RESULTS: + return None + close_dt = _trade_close_dt(d) + if not close_dt: + return None + pnl = _effective_pnl(d) + closed_at = _effective_field(d, "reviewed_closed_at", "closed_at") + opened_at = _effective_field(d, "reviewed_opened_at", "opened_at") + opened_ms = d.get("opened_at_ms") + closed_ms = d.get("closed_at_ms") + if opened_ms in (None, ""): + odt = parse_dt_for_trading_day(opened_at) + opened_ms = int(odt.timestamp() * 1000) if odt else None + if closed_ms in (None, ""): + cdt = close_dt + closed_ms = int(cdt.timestamp() * 1000) if cdt else None + try: + trade_id = int(d.get("id")) + except (TypeError, ValueError): + return None + opened_ms_i = int(opened_ms) if opened_ms else None + closed_ms_i = int(closed_ms) if closed_ms else None + hold_m = effective_hold_minutes(d, opened_ms=opened_ms_i, closed_ms=closed_ms_i) + entry_type = display_entry_type_label(d) + reviewed = bool( + d.get("reviewed_at") + or d.get("reviewed_result") + or d.get("reviewed_opened_at") + or d.get("reviewed_closed_at") + or d.get("reviewed_entry_reason") + or d.get("reviewed_hold_minutes") + ) + return { + "id": trade_id, + "exchange_key": (exchange_key or "").strip().lower(), + "symbol": (d.get("symbol") or "").strip().upper(), + "direction": d.get("direction"), + "result": effective_result, + "pnl_amount": round(pnl, 4), + "closed_at": closed_at, + "opened_at": opened_at, + "opened_at_ms": opened_ms_i, + "closed_at_ms": closed_ms_i, + "monitor_type": _normalize_monitor_type_label(d.get("monitor_type")), + "entry_type": entry_type, + "entry_reason": entry_type, + "hold_minutes": hold_m, + "hold_minutes_text": format_hold_minutes(hold_m), + "actual_rr": d.get("actual_rr"), + "planned_rr": d.get("planned_rr"), + "trade_style": d.get("trade_style"), + "trigger_price": d.get("trigger_price"), + "stop_loss": _effective_field(d, "reviewed_stop_loss", "stop_loss"), + "take_profit": _effective_field(d, "reviewed_take_profit", "take_profit"), + "reviewed": reviewed, + "trading_day": trading_day_from_dt(close_dt, reset_hour), + "exchange_turnover_usdt": d.get("exchange_turnover_usdt"), + "exchange_commission_usdt": d.get("exchange_commission_usdt"), + } + + +_SNAPSHOT_STATUS_TO_RESULT = { + "stopped_sl": "止损", + "stopped_tp": "止盈", + "stopped_manual": "手动平仓", + "stopped_external": "外部平仓", +} + + +def _table_columns(conn, table: str) -> set[str]: + try: + rows = conn.execute(f"PRAGMA table_info({table})").fetchall() + except Exception: + return set() + out: set[str] = set() + for r in rows: + try: + out.add(str(r[1])) + except (IndexError, TypeError): + try: + out.add(str(r["name"])) + except Exception: + continue + return out + + +def _archive_ts_expr(cols: set[str]) -> str: + parts = [c for c in ("reviewed_closed_at", "closed_at", "created_at", "opened_at") if c in cols] + if not parts: + return "''" + return f"REPLACE(COALESCE({', '.join(parts)}), 'T', ' ')" + + +def _archive_trade_select_sql(cols: set[str]) -> str: + wanted = [ + "id", + "symbol", + "direction", + "result", + "reviewed_result", + "pnl_amount", + "reviewed_pnl_amount", + "exchange_realized_pnl", + "closed_at", + "reviewed_closed_at", + "opened_at", + "reviewed_opened_at", + "opened_at_ms", + "closed_at_ms", + "created_at", + "monitor_type", + "key_signal_type", + "actual_rr", + "planned_rr", + "trade_style", + "entry_reason", + "reviewed_entry_reason", + "hold_minutes", + "reviewed_hold_minutes", + "hold_seconds", + "reviewed_hold_seconds", + "trigger_price", + "stop_loss", + "take_profit", + "reviewed_stop_loss", + "reviewed_take_profit", + "reviewed_at", + "trend_plan_id", + "exchange_turnover_usdt", + "exchange_commission_usdt", + ] + select_cols = [c for c in wanted if c in cols] + if "id" not in select_cols: + select_cols = ["id"] + select_cols + return ", ".join(select_cols) + + +def _existing_trend_plan_ids(conn) -> set[int]: + cols = _table_columns(conn, "trade_records") + if "trend_plan_id" not in cols: + return set() + rows = conn.execute( + "SELECT DISTINCT trend_plan_id FROM trade_records WHERE trend_plan_id IS NOT NULL" + ).fetchall() + out: set[int] = set() + for row in rows: + d = _row_dict(row) + try: + out.add(int(d.get("trend_plan_id"))) + except (TypeError, ValueError): + continue + return out + + +def _normalize_snapshot_archive_row( + snap: dict, + *, + exchange_key: str = "", + reset_hour: int = 8, +) -> dict[str, Any] | None: + result = str(snap.get("result_label") or "").strip() + if not result: + result = _SNAPSHOT_STATUS_TO_RESULT.get( + str(snap.get("status_at_close") or "").strip(), "" + ) + if result not in TRADE_COMPLETED_RESULTS: + return None + closed_at = snap.get("closed_at") + close_dt = parse_dt_for_trading_day(closed_at) + if not close_dt: + return None + opened_at = snap.get("opened_at") + opened_ms = _parse_ms_from_row(snap.get("opened_at")) + closed_ms = _parse_ms_from_row(closed_at) + try: + snap_id = int(snap.get("id")) + except (TypeError, ValueError): + return None + try: + pnl = float(snap.get("pnl_amount") or 0) + except (TypeError, ValueError): + pnl = 0.0 + st = str(snap.get("strategy_type") or "").strip() + monitor_type = _normalize_monitor_type_label( + "trend_pullback" if st == "trend_pullback" else ("roll" if st == "roll" else st) + ) + hold_m = effective_hold_minutes( + {}, + opened_ms=opened_ms, + closed_ms=closed_ms, + ) + entry_type = entry_reason_for_monitor_type(monitor_type) or monitor_type + return { + "id": -snap_id, + "exchange_key": (exchange_key or "").strip().lower(), + "symbol": (snap.get("symbol") or "").strip().upper(), + "direction": snap.get("direction"), + "result": result, + "pnl_amount": round(pnl, 4), + "closed_at": closed_at, + "opened_at": opened_at, + "opened_at_ms": opened_ms, + "closed_at_ms": closed_ms, + "monitor_type": monitor_type, + "entry_type": entry_type, + "entry_reason": entry_type, + "hold_minutes": hold_m, + "hold_minutes_text": format_hold_minutes(hold_m), + "from_snapshot": True, + "snapshot_id": snap_id, + "trend_plan_id": snap.get("source_id"), + "reviewed": False, + "trading_day": trading_day_from_dt(close_dt, reset_hour), + } + + +def _parse_ms_from_row(raw: Any) -> int | None: + if raw in (None, ""): + return None + try: + if isinstance(raw, (int, float)): + v = int(raw) + return v if v > 1_000_000_000_000 else v * 1000 + except (TypeError, ValueError): + pass + dt = parse_dt_for_trading_day(raw) + return int(dt.timestamp() * 1000) if dt else None + + +def _fetch_strategy_snapshots_for_archive( + conn, + *, + exchange_key: str = "", + days: int = 365, + reset_hour: int = 8, + limit: int = 2000, + skip_plan_ids: set[int] | None = None, +) -> list[dict[str, Any]]: + cols = _table_columns(conn, "strategy_trade_snapshots") + if not cols: + return [] + lim = max(1, min(int(limit or 2000), 5000)) + day_span = max(1, min(int(days or 365), 3650)) + cutoff = datetime.now() - timedelta(days=day_span) + cutoff_s = cutoff.strftime("%Y-%m-%d %H:%M:%S") + ts_expr = "REPLACE(COALESCE(closed_at, opened_at, created_at), 'T', ' ')" + rows = conn.execute( + f""" + SELECT * FROM strategy_trade_snapshots + WHERE {ts_expr} >= ? + ORDER BY {ts_expr} DESC + LIMIT ? + """, + (cutoff_s, lim * 2), + ).fetchall() + skip = skip_plan_ids or set() + out: list[dict[str, Any]] = [] + for row in rows: + d = _row_dict(row) + try: + source_id = int(d.get("source_id") or 0) + except (TypeError, ValueError): + source_id = 0 + if source_id > 0 and source_id in skip: + continue + norm = _normalize_snapshot_archive_row( + d, exchange_key=exchange_key, reset_hour=reset_hour + ) + if norm: + out.append(norm) + if len(out) >= lim: + break + return out + + +def fetch_trades_for_archive( + conn, + *, + exchange_key: str = "", + days: int = 365, + row_to_dict_fn: Optional[Callable] = None, + reset_hour: int = 8, + limit: int = 2000, + include_strategy_snapshots: bool = True, +) -> list[dict[str, Any]]: + """返回近 N 天已平仓记录(trade_records + 未落库的 strategy 快照)。""" + lim = max(1, min(int(limit or 2000), 5000)) + day_span = max(1, min(int(days or 365), 3650)) + cutoff = datetime.now() - timedelta(days=day_span) + cutoff_s = cutoff.strftime("%Y-%m-%d %H:%M:%S") + cols = _table_columns(conn, "trade_records") + if not cols: + records: list[dict[str, Any]] = [] + else: + ts_expr = _archive_ts_expr(cols) + sql = f""" + SELECT {_archive_trade_select_sql(cols)} + FROM trade_records + WHERE {ts_expr} >= ? + ORDER BY {ts_expr} DESC + LIMIT ? + """ + rows = conn.execute(sql, (cutoff_s, lim * 2)).fetchall() + records = [] + for row in rows: + d = _row_dict(row, row_to_dict_fn) + norm = _normalize_archive_trade_row( + d, exchange_key=exchange_key, reset_hour=reset_hour + ) + if norm: + records.append(norm) + if len(records) >= lim: + break + + if not include_strategy_snapshots: + return records + + skip_ids = _existing_trend_plan_ids(conn) + for rec in records: + try: + pid = int(rec.get("trend_plan_id") or 0) + except (TypeError, ValueError): + pid = 0 + if pid > 0: + skip_ids.add(pid) + + snaps = _fetch_strategy_snapshots_for_archive( + conn, + days=days, + exchange_key=exchange_key, + reset_hour=reset_hour, + limit=max(0, lim - len(records)), + skip_plan_ids=skip_ids, + ) + merged = records + snaps + merged.sort( + key=lambda x: int(x.get("closed_at_ms") or 0), + reverse=True, + ) + return merged[:lim] + + +def summarize_trades(trades: list[dict]) -> dict[str, Any]: + """单笔列表 → 笔数 / 盈亏 / 胜败统计。""" + total_pnl = 0.0 + win = loss = flat = 0 + for t in trades or []: + try: + pnl = float(t.get("pnl_amount") or 0) + except (TypeError, ValueError): + pnl = 0.0 + total_pnl += pnl + if pnl > 1e-9: + win += 1 + elif pnl < -1e-9: + loss += 1 + else: + flat += 1 + return { + "closed_count": len(trades or []), + "win_count": win, + "loss_count": loss, + "flat_count": flat, + "total_pnl_u": round(total_pnl, 4), + } diff --git a/hub_volume_rank_lib.py b/lib/hub/hub_volume_rank_lib.py similarity index 96% rename from hub_volume_rank_lib.py rename to lib/hub/hub_volume_rank_lib.py index 1e7a2c4..e4ba950 100644 --- a/hub_volume_rank_lib.py +++ b/lib/hub/hub_volume_rank_lib.py @@ -1,595 +1,595 @@ -"""行情区:各交易所 USDT 永续昨日成交额 Top N(每日 8:00 快照)。""" - -from __future__ import annotations - -import json -import os -from datetime import datetime, timedelta -from pathlib import Path -from typing import Any, Callable -from zoneinfo import ZoneInfo - -from hub_trades_lib import trading_day_from_dt - -TOP_N_DEFAULT = 20 -CACHE_VERSION = 3 -LIQUIDITY_RANK_CACHE_VERSION = 1 - - -def volume_rank_reset_hour() -> int: - try: - return max(0, min(23, int(os.getenv("HUB_VOLUME_RANK_RESET_HOUR", "8")))) - except ValueError: - return 8 - - -def volume_rank_timezone() -> ZoneInfo: - name = (os.getenv("HUB_VOLUME_RANK_TZ") or "Asia/Shanghai").strip() or "Asia/Shanghai" - try: - return ZoneInfo(name) - except Exception: - return ZoneInfo("Asia/Shanghai") - - -def rank_date_label(*, now: datetime | None = None, reset_hour: int | None = None) -> str: - """8 点更新后展示的「昨日」交易日(与 TRADING_DAY_RESET_HOUR 口径一致)。""" - rh = volume_rank_reset_hour() if reset_hour is None else reset_hour - tz = volume_rank_timezone() - dt = now.astimezone(tz) if now else datetime.now(tz) - cur_td = trading_day_from_dt(dt.replace(tzinfo=None), rh) - cur = datetime.strptime(cur_td, "%Y-%m-%d").date() - return (cur - timedelta(days=1)).isoformat() - - -def seconds_until_next_reset( - *, - now: datetime | None = None, - reset_hour: int | None = None, -) -> float: - rh = volume_rank_reset_hour() if reset_hour is None else reset_hour - tz = volume_rank_timezone() - dt = now.astimezone(tz) if now else datetime.now(tz) - nxt = dt.replace(hour=rh, minute=0, second=0, microsecond=0) - if dt >= nxt: - nxt += timedelta(days=1) - return max(1.0, (nxt - dt).total_seconds()) - - -def default_cache_path() -> Path: - raw = (os.getenv("HUB_VOLUME_RANK_CACHE_PATH") or "").strip() - if raw: - return Path(raw) - hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data" - hub_dir.mkdir(parents=True, exist_ok=True) - return hub_dir / "hub_volume_rank.json" - - -def _safe_float(v: Any) -> float | None: - try: - n = float(v) - return n if n == n else None - except (TypeError, ValueError): - return None - - -def _ticker_base(sym_text: str) -> str: - s = str(sym_text or "").upper().strip() - if ":" in s: - s = s.split(":", 1)[0] - if "/" in s: - return s.split("/", 1)[0].strip() - if "-" in s: - return s.split("-", 1)[0].strip() - if s.endswith("USDT"): - return s[:-4].strip() - return s - - -def _hub_symbol_from_base(base: str, quote: str = "USDT") -> str: - b = str(base or "").strip().upper() - q = str(quote or "USDT").strip().upper() - return f"{b}/{q}" if b else "" - - -def _hub_symbol_from_market(market: dict | None, fallback_symbol: str) -> str: - if market: - base = str(market.get("base") or "").strip().upper() - quote = str(market.get("quote") or "USDT").strip().upper() - if base: - return f"{base}/{quote}" - fb = str(fallback_symbol or "").upper().strip() - if ":" in fb: - fb = fb.split(":", 1)[0] - if "/" in fb: - return fb - base = _ticker_base(fb) - return f"{base}/USDT" if base else fb - - -def _okx_turnover_usdt(row: dict | None) -> float | None: - """OKX SWAP:成交额(USDT) ≈ volCcy24h(基础币) × last。""" - if not isinstance(row, dict): - return None - base_vol = _safe_float(row.get("volCcy24h")) - if base_vol is None or base_vol <= 0: - return None - last = _safe_float(row.get("last") or row.get("lastPx")) - if last is None or last <= 0: - return None - return float(base_vol * last) - - -def _quote_volume_from_ticker( - ticker: dict | None, - market: dict | None, - *, - exchange_id: str = "", -) -> float | None: - ex_id = str(exchange_id or "").lower() - t = ticker or {} - info = t.get("info") if isinstance(t.get("info"), dict) else {} - - if ex_id == "okx": - row = dict(info) - if row.get("last") is None: - row["last"] = t.get("last") - qv = _okx_turnover_usdt(row) - if qv is not None and qv > 0: - return qv - - qv = _safe_float(t.get("quoteVolume")) - if qv is not None and qv > 0: - return qv - - if ex_id in ("gateio", "gate"): - for key in ( - "volume_24h_quote", - "volume_24h_settle", - "quote_volume", - "vol_24h", - "turnover", - ): - qv = _safe_float(info.get(key)) - if qv is not None and qv > 0: - return qv - - for key in ("quoteVolume", "volCcy24h", "vol24h", "turnover24h", "amount24", "turnover"): - qv = _safe_float(info.get(key)) - if qv is not None and qv > 0: - if key == "volCcy24h" and ex_id == "okx": - last = _safe_float(info.get("last") or info.get("lastPx") or t.get("last")) - if last: - return qv * last - return qv - - bv = _safe_float(t.get("baseVolume")) - lp = _safe_float(t.get("last")) or _safe_float(t.get("close")) - if bv is not None and lp is not None and bv > 0 and lp > 0: - return bv * lp - - if info: - bv = _safe_float(info.get("volCcy24h") or info.get("vol24h") or info.get("volume")) - lp = _safe_float(info.get("last") or info.get("lastPx") or info.get("markPrice")) - if bv is not None and lp is not None and bv > 0 and lp > 0: - return bv * lp - - return None - - -def _is_usdt_linear_swap(market: dict | None, symbol: str) -> bool: - if not market: - su = str(symbol or "").upper() - return "USDT" in su and (":USDT" in su or "/USDT" in su or su.endswith("USDT")) - if not market.get("swap") and market.get("type") not in ("swap", "future"): - return False - if str(market.get("quote") or "").upper() != "USDT": - return False - if market.get("linear") is False: - return False - if market.get("active") is False: - return False - settle = str(market.get("settle") or "").upper() - if settle and settle != "USDT": - return False - return True - - -def _lookup_ticker(tickers: dict, sym: str, market: dict | None) -> dict | None: - if not tickers: - return None - t = tickers.get(sym) - if t: - return t - if not market: - return None - base = market.get("base") - quote = market.get("quote") or "USDT" - settle = market.get("settle") or quote - candidates = [ - sym, - f"{base}/{quote}:{settle}", - f"{base}/{quote}", - f"{base}{quote}", - market.get("id"), - ] - for key in candidates: - if not key: - continue - t = tickers.get(key) - if t: - return t - return None - - -def _merge_scores(scored: dict[str, tuple[str, float]]) -> list[tuple[str, str, float]]: - rows = [(sym, base, vol) for base, (sym, vol) in scored.items() if sym and base and vol > 0] - rows.sort(key=lambda x: x[2], reverse=True) - return rows - - -def _scores_from_okx(exchange) -> list[tuple[str, str, float]]: - by_base: dict[str, tuple[str, float]] = {} - if hasattr(exchange, "publicGetMarketTickers"): - try: - resp = exchange.publicGetMarketTickers({"instType": "SWAP"}) - for row in (resp or {}).get("data") or []: - if not isinstance(row, dict): - continue - inst = str(row.get("instId") or "").upper() - parts = inst.split("-") - if len(parts) < 3 or parts[-1] != "SWAP" or parts[1] != "USDT": - continue - base = parts[0].strip() - if not base: - continue - qv = _okx_turnover_usdt(row) - if qv is None or qv <= 0: - continue - sym = _hub_symbol_from_base(base) - prev = by_base.get(base) - if prev is None or qv > prev[1]: - by_base[base] = (sym, float(qv)) - if by_base: - return _merge_scores(by_base) - except Exception: - pass - - try: - tickers = exchange.fetch_tickers(params={"instType": "SWAP"}) - except Exception: - tickers = exchange.fetch_tickers() - return _scores_from_markets(exchange, tickers or {}, "okx") - - -def _scores_from_binance(exchange) -> list[tuple[str, str, float]]: - by_base: dict[str, tuple[str, float]] = {} - if hasattr(exchange, "fapiPublicGetTicker24hr"): - try: - rows = exchange.fapiPublicGetTicker24hr() - if isinstance(rows, list): - for row in rows: - if not isinstance(row, dict): - continue - raw = str(row.get("symbol") or "").upper() - if not raw.endswith("USDT"): - continue - base = raw[:-4] - if not base: - continue - qv = _safe_float(row.get("quoteVolume")) - if qv is None or qv <= 0: - bv = _safe_float(row.get("volume")) - lp = _safe_float(row.get("lastPrice") or row.get("weightedAvgPrice")) - if bv and lp: - qv = bv * lp - if qv is None or qv <= 0: - continue - sym = _hub_symbol_from_base(base) - prev = by_base.get(base) - if prev is None or qv > prev[1]: - by_base[base] = (sym, float(qv)) - if by_base: - return _merge_scores(by_base) - except Exception: - pass - return [] - - -def _scores_from_gate(exchange) -> list[tuple[str, str, float]]: - by_base: dict[str, tuple[str, float]] = {} - for method_name in ("publicFuturesGetSettleTickers", "publicFuturesGetUsdtTickers"): - fn = getattr(exchange, method_name, None) - if not callable(fn): - continue - try: - rows = fn({"settle": "usdt"}) - if isinstance(rows, list): - for row in rows: - if not isinstance(row, dict): - continue - contract = str(row.get("contract") or row.get("name") or "").upper() - if not contract: - continue - base = contract.replace("_USDT", "").replace("USDT", "").strip("_") - if not base: - continue - qv = _safe_float(row.get("volume_24h_quote") or row.get("volume_24h_settle")) - if qv is None or qv <= 0: - bv = _safe_float(row.get("volume_24h_base")) - lp = _safe_float(row.get("last") or row.get("mark_price")) - if bv and lp: - qv = bv * lp - if qv is None or qv <= 0: - continue - sym = _hub_symbol_from_base(base) - prev = by_base.get(base) - if prev is None or qv > prev[1]: - by_base[base] = (sym, float(qv)) - if by_base: - return _merge_scores(by_base) - except Exception: - continue - return [] - - -def _scores_from_markets( - exchange, - tickers: dict, - exchange_id: str, -) -> list[tuple[str, str, float]]: - by_base: dict[str, tuple[str, float]] = {} - markets = getattr(exchange, "markets", None) or {} - for sym, mk in markets.items(): - try: - if not _is_usdt_linear_swap(mk, sym): - continue - ticker = _lookup_ticker(tickers, sym, mk) - qv = _quote_volume_from_ticker(ticker, mk, exchange_id=exchange_id) - if qv is None or qv <= 0: - continue - hub_sym = _hub_symbol_from_market(mk, sym) - base = _ticker_base(hub_sym) - if not base: - continue - prev = by_base.get(base) - if prev is None or qv > prev[1]: - by_base[base] = (hub_sym, float(qv)) - except Exception: - continue - return _merge_scores(by_base) - - -def _collect_scores(exchange, exchange_id: str) -> list[tuple[str, str, float]]: - ex_id = str(exchange_id or "").lower() - if ex_id == "okx": - return _scores_from_okx(exchange) - if ex_id == "binance": - return _scores_from_binance(exchange) - if ex_id in ("gateio", "gate", "gate_bot"): - return _scores_from_gate(exchange) - tickers = exchange.fetch_tickers() - return _scores_from_markets(exchange, tickers or {}, ex_id) - - -def _uses_lightweight_volume_scores(exchange_id: str) -> bool: - ex_id = str(exchange_id or "").lower() - return ex_id in ("okx", "binance", "gateio", "gate", "gate_bot") - - -def build_usdt_swap_volume_ranks( - exchange, - ensure_markets_loaded: Callable[[], None], - *, - exchange_id: str | None = None, -) -> tuple[dict[str, int], int]: - """ - 全市场 USDT 永续 24h 成交额排名(base -> rank)。 - 优先各所轻量 ticker API,避免 fetch_tickers() 拉全市场(Gate/Binance 内存优化)。 - """ - ex_id = str(exchange_id or getattr(exchange, "id", "") or "").lower() - if not _uses_lightweight_volume_scores(ex_id): - ensure_markets_loaded() - scored = _collect_scores(exchange, ex_id) - ranks: dict[str, int] = {} - for idx, (_sym, base, _qv) in enumerate(scored, 1): - if base and base not in ranks: - ranks[base] = idx - return ranks, len(scored) - - -def resolve_daily_volume_rank( - target_base: str, - cache: dict[str, Any], - *, - now_ts: float, - ttl_sec: float, - exchange, - ensure_markets_loaded: Callable[[], None], - exchange_id: str | None = None, - cache_version: int = LIQUIDITY_RANK_CACHE_VERSION, -) -> tuple[int | None, int]: - """关键位门控:按 base 查 24h 成交额全市场排名;cache 带 TTL。""" - cached_ok = ( - cache.get("version") == cache_version - and cache.get("updated_at") - and now_ts - float(cache["updated_at"]) < ttl_sec - ) - if not cached_ok: - try: - ranks, total = build_usdt_swap_volume_ranks( - exchange, - ensure_markets_loaded, - exchange_id=exchange_id, - ) - if total > 0 and ranks: - cache["ranks"] = ranks - cache["total"] = total - cache["version"] = cache_version - cache["updated_at"] = now_ts - except Exception: - pass - ranks = cache.get("ranks") or {} - total = int(cache.get("total") or 0) - base = str(target_base or "").strip().upper() - return ranks.get(base), total - - -def fetch_usdt_swap_volume_rank( - exchange, - ensure_markets_loaded: Callable[[], None], - *, - top_n: int = TOP_N_DEFAULT, - rank_date: str | None = None, - exchange_id: str | None = None, -) -> dict[str, Any]: - """从 ccxt 拉全市场 USDT 永续 ticker,按 24h 成交额(USDT) 取 Top N。""" - top_n = max(1, min(int(top_n or TOP_N_DEFAULT), 100)) - ensure_markets_loaded() - ex_id = str(exchange_id or getattr(exchange, "id", "") or "").lower() - - try: - scored = _collect_scores(exchange, ex_id) - except Exception as e: - return {"ok": False, "msg": str(e)} - - items = [] - for idx, (hub_sym, base, qv) in enumerate(scored[:top_n], 1): - items.append( - { - "rank": idx, - "symbol": hub_sym, - "base": base, - "volume_quote": round(qv, 4), - } - ) - return { - "ok": True, - "rank_date": rank_date or rank_date_label(), - "items": items, - "total_symbols": len(scored), - "exchange_id": ex_id, - "fetched_at": datetime.now(volume_rank_timezone()).isoformat(timespec="seconds"), - } - - -def format_volume_quote(value: float | None) -> str: - n = _safe_float(value) - if n is None or n <= 0: - return "—" - if n >= 1e9: - return f"{n / 1e9:.2f}B" - if n >= 1e6: - return f"{n / 1e6:.2f}M" - if n >= 1e3: - return f"{n / 1e3:.2f}K" - return f"{n:.0f}" - - -def load_volume_rank_cache(path: Path | None = None) -> dict[str, Any]: - p = path or default_cache_path() - if not p.is_file(): - return {"version": CACHE_VERSION, "exchanges": {}} - try: - data = json.loads(p.read_text(encoding="utf-8")) - if not isinstance(data, dict): - return {"version": CACHE_VERSION, "exchanges": {}} - if int(data.get("version") or 0) < CACHE_VERSION: - return {"version": CACHE_VERSION, "exchanges": {}} - data.setdefault("version", CACHE_VERSION) - data.setdefault("exchanges", {}) - return data - except Exception: - return {"version": CACHE_VERSION, "exchanges": {}} - - -def save_volume_rank_cache(data: dict[str, Any], path: Path | None = None) -> None: - p = path or default_cache_path() - p.parent.mkdir(parents=True, exist_ok=True) - payload = dict(data) - payload["version"] = CACHE_VERSION - payload["updated_at"] = datetime.now(volume_rank_timezone()).isoformat(timespec="seconds") - p.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - - -def merge_exchange_rank( - cache: dict[str, Any], - exchange_key: str, - payload: dict[str, Any], -) -> dict[str, Any]: - ex_k = str(exchange_key or "").strip().lower() - if not ex_k or not payload.get("ok"): - return cache - exchanges = dict(cache.get("exchanges") or {}) - exchanges[ex_k] = { - "rank_date": payload.get("rank_date"), - "items": payload.get("items") or [], - "total_symbols": int(payload.get("total_symbols") or 0), - "fetched_at": payload.get("fetched_at"), - "error": None, - } - out = dict(cache) - out["exchanges"] = exchanges - out["rank_date"] = payload.get("rank_date") or cache.get("rank_date") - return out - - -def _exchange_rank_row_stale(row: dict[str, Any] | None) -> bool: - if not row: - return True - items = row.get("items") or [] - if len(items) < TOP_N_DEFAULT: - return True - total = int(row.get("total_symbols") or 0) - if total > 0 and total < TOP_N_DEFAULT: - return True - return False - - -def cache_needs_refresh( - cache: dict[str, Any], - *, - expected_rank_date: str | None = None, - required_keys: list[str] | None = None, -) -> bool: - expected = expected_rank_date or rank_date_label() - if int(cache.get("version") or 0) < CACHE_VERSION: - return True - exchanges = cache.get("exchanges") or {} - if not exchanges: - return True - if str(cache.get("rank_date") or "") != expected: - return True - keys = required_keys or list(exchanges.keys()) - if not keys: - return True - for key in keys: - ex_k = str(key or "").strip().lower() - if not ex_k: - continue - if _exchange_rank_row_stale(exchanges.get(ex_k)): - return True - return False - - -def get_cached_rank( - cache: dict[str, Any], - exchange_key: str, - *, - top_n: int = TOP_N_DEFAULT, -) -> dict[str, Any]: - ex_k = str(exchange_key or "").strip().lower() - ex_data = (cache.get("exchanges") or {}).get(ex_k) or {} - items = list(ex_data.get("items") or [])[: max(1, int(top_n))] - stale = _exchange_rank_row_stale(ex_data) - return { - "ok": True, - "exchange_key": ex_k, - "rank_date": ex_data.get("rank_date") or cache.get("rank_date"), - "updated_at": cache.get("updated_at"), - "items": items, - "item_count": len(items), - "expected_count": int(top_n), - "total_symbols": int(ex_data.get("total_symbols") or 0), - "stale": stale, - "error": ex_data.get("error"), - } +"""行情区:各交易所 USDT 永续昨日成交额 Top N(每日 8:00 快照)。""" + +from __future__ import annotations + +import json +import os +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Callable +from zoneinfo import ZoneInfo + +from lib.hub.hub_trades_lib import trading_day_from_dt + +TOP_N_DEFAULT = 20 +CACHE_VERSION = 3 +LIQUIDITY_RANK_CACHE_VERSION = 1 + + +def volume_rank_reset_hour() -> int: + try: + return max(0, min(23, int(os.getenv("HUB_VOLUME_RANK_RESET_HOUR", "8")))) + except ValueError: + return 8 + + +def volume_rank_timezone() -> ZoneInfo: + name = (os.getenv("HUB_VOLUME_RANK_TZ") or "Asia/Shanghai").strip() or "Asia/Shanghai" + try: + return ZoneInfo(name) + except Exception: + return ZoneInfo("Asia/Shanghai") + + +def rank_date_label(*, now: datetime | None = None, reset_hour: int | None = None) -> str: + """8 点更新后展示的「昨日」交易日(与 TRADING_DAY_RESET_HOUR 口径一致)。""" + rh = volume_rank_reset_hour() if reset_hour is None else reset_hour + tz = volume_rank_timezone() + dt = now.astimezone(tz) if now else datetime.now(tz) + cur_td = trading_day_from_dt(dt.replace(tzinfo=None), rh) + cur = datetime.strptime(cur_td, "%Y-%m-%d").date() + return (cur - timedelta(days=1)).isoformat() + + +def seconds_until_next_reset( + *, + now: datetime | None = None, + reset_hour: int | None = None, +) -> float: + rh = volume_rank_reset_hour() if reset_hour is None else reset_hour + tz = volume_rank_timezone() + dt = now.astimezone(tz) if now else datetime.now(tz) + nxt = dt.replace(hour=rh, minute=0, second=0, microsecond=0) + if dt >= nxt: + nxt += timedelta(days=1) + return max(1.0, (nxt - dt).total_seconds()) + + +def default_cache_path() -> Path: + raw = (os.getenv("HUB_VOLUME_RANK_CACHE_PATH") or "").strip() + if raw: + return Path(raw) + hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data" + hub_dir.mkdir(parents=True, exist_ok=True) + return hub_dir / "hub_volume_rank.json" + + +def _safe_float(v: Any) -> float | None: + try: + n = float(v) + return n if n == n else None + except (TypeError, ValueError): + return None + + +def _ticker_base(sym_text: str) -> str: + s = str(sym_text or "").upper().strip() + if ":" in s: + s = s.split(":", 1)[0] + if "/" in s: + return s.split("/", 1)[0].strip() + if "-" in s: + return s.split("-", 1)[0].strip() + if s.endswith("USDT"): + return s[:-4].strip() + return s + + +def _hub_symbol_from_base(base: str, quote: str = "USDT") -> str: + b = str(base or "").strip().upper() + q = str(quote or "USDT").strip().upper() + return f"{b}/{q}" if b else "" + + +def _hub_symbol_from_market(market: dict | None, fallback_symbol: str) -> str: + if market: + base = str(market.get("base") or "").strip().upper() + quote = str(market.get("quote") or "USDT").strip().upper() + if base: + return f"{base}/{quote}" + fb = str(fallback_symbol or "").upper().strip() + if ":" in fb: + fb = fb.split(":", 1)[0] + if "/" in fb: + return fb + base = _ticker_base(fb) + return f"{base}/USDT" if base else fb + + +def _okx_turnover_usdt(row: dict | None) -> float | None: + """OKX SWAP:成交额(USDT) ≈ volCcy24h(基础币) × last。""" + if not isinstance(row, dict): + return None + base_vol = _safe_float(row.get("volCcy24h")) + if base_vol is None or base_vol <= 0: + return None + last = _safe_float(row.get("last") or row.get("lastPx")) + if last is None or last <= 0: + return None + return float(base_vol * last) + + +def _quote_volume_from_ticker( + ticker: dict | None, + market: dict | None, + *, + exchange_id: str = "", +) -> float | None: + ex_id = str(exchange_id or "").lower() + t = ticker or {} + info = t.get("info") if isinstance(t.get("info"), dict) else {} + + if ex_id == "okx": + row = dict(info) + if row.get("last") is None: + row["last"] = t.get("last") + qv = _okx_turnover_usdt(row) + if qv is not None and qv > 0: + return qv + + qv = _safe_float(t.get("quoteVolume")) + if qv is not None and qv > 0: + return qv + + if ex_id in ("gateio", "gate"): + for key in ( + "volume_24h_quote", + "volume_24h_settle", + "quote_volume", + "vol_24h", + "turnover", + ): + qv = _safe_float(info.get(key)) + if qv is not None and qv > 0: + return qv + + for key in ("quoteVolume", "volCcy24h", "vol24h", "turnover24h", "amount24", "turnover"): + qv = _safe_float(info.get(key)) + if qv is not None and qv > 0: + if key == "volCcy24h" and ex_id == "okx": + last = _safe_float(info.get("last") or info.get("lastPx") or t.get("last")) + if last: + return qv * last + return qv + + bv = _safe_float(t.get("baseVolume")) + lp = _safe_float(t.get("last")) or _safe_float(t.get("close")) + if bv is not None and lp is not None and bv > 0 and lp > 0: + return bv * lp + + if info: + bv = _safe_float(info.get("volCcy24h") or info.get("vol24h") or info.get("volume")) + lp = _safe_float(info.get("last") or info.get("lastPx") or info.get("markPrice")) + if bv is not None and lp is not None and bv > 0 and lp > 0: + return bv * lp + + return None + + +def _is_usdt_linear_swap(market: dict | None, symbol: str) -> bool: + if not market: + su = str(symbol or "").upper() + return "USDT" in su and (":USDT" in su or "/USDT" in su or su.endswith("USDT")) + if not market.get("swap") and market.get("type") not in ("swap", "future"): + return False + if str(market.get("quote") or "").upper() != "USDT": + return False + if market.get("linear") is False: + return False + if market.get("active") is False: + return False + settle = str(market.get("settle") or "").upper() + if settle and settle != "USDT": + return False + return True + + +def _lookup_ticker(tickers: dict, sym: str, market: dict | None) -> dict | None: + if not tickers: + return None + t = tickers.get(sym) + if t: + return t + if not market: + return None + base = market.get("base") + quote = market.get("quote") or "USDT" + settle = market.get("settle") or quote + candidates = [ + sym, + f"{base}/{quote}:{settle}", + f"{base}/{quote}", + f"{base}{quote}", + market.get("id"), + ] + for key in candidates: + if not key: + continue + t = tickers.get(key) + if t: + return t + return None + + +def _merge_scores(scored: dict[str, tuple[str, float]]) -> list[tuple[str, str, float]]: + rows = [(sym, base, vol) for base, (sym, vol) in scored.items() if sym and base and vol > 0] + rows.sort(key=lambda x: x[2], reverse=True) + return rows + + +def _scores_from_okx(exchange) -> list[tuple[str, str, float]]: + by_base: dict[str, tuple[str, float]] = {} + if hasattr(exchange, "publicGetMarketTickers"): + try: + resp = exchange.publicGetMarketTickers({"instType": "SWAP"}) + for row in (resp or {}).get("data") or []: + if not isinstance(row, dict): + continue + inst = str(row.get("instId") or "").upper() + parts = inst.split("-") + if len(parts) < 3 or parts[-1] != "SWAP" or parts[1] != "USDT": + continue + base = parts[0].strip() + if not base: + continue + qv = _okx_turnover_usdt(row) + if qv is None or qv <= 0: + continue + sym = _hub_symbol_from_base(base) + prev = by_base.get(base) + if prev is None or qv > prev[1]: + by_base[base] = (sym, float(qv)) + if by_base: + return _merge_scores(by_base) + except Exception: + pass + + try: + tickers = exchange.fetch_tickers(params={"instType": "SWAP"}) + except Exception: + tickers = exchange.fetch_tickers() + return _scores_from_markets(exchange, tickers or {}, "okx") + + +def _scores_from_binance(exchange) -> list[tuple[str, str, float]]: + by_base: dict[str, tuple[str, float]] = {} + if hasattr(exchange, "fapiPublicGetTicker24hr"): + try: + rows = exchange.fapiPublicGetTicker24hr() + if isinstance(rows, list): + for row in rows: + if not isinstance(row, dict): + continue + raw = str(row.get("symbol") or "").upper() + if not raw.endswith("USDT"): + continue + base = raw[:-4] + if not base: + continue + qv = _safe_float(row.get("quoteVolume")) + if qv is None or qv <= 0: + bv = _safe_float(row.get("volume")) + lp = _safe_float(row.get("lastPrice") or row.get("weightedAvgPrice")) + if bv and lp: + qv = bv * lp + if qv is None or qv <= 0: + continue + sym = _hub_symbol_from_base(base) + prev = by_base.get(base) + if prev is None or qv > prev[1]: + by_base[base] = (sym, float(qv)) + if by_base: + return _merge_scores(by_base) + except Exception: + pass + return [] + + +def _scores_from_gate(exchange) -> list[tuple[str, str, float]]: + by_base: dict[str, tuple[str, float]] = {} + for method_name in ("publicFuturesGetSettleTickers", "publicFuturesGetUsdtTickers"): + fn = getattr(exchange, method_name, None) + if not callable(fn): + continue + try: + rows = fn({"settle": "usdt"}) + if isinstance(rows, list): + for row in rows: + if not isinstance(row, dict): + continue + contract = str(row.get("contract") or row.get("name") or "").upper() + if not contract: + continue + base = contract.replace("_USDT", "").replace("USDT", "").strip("_") + if not base: + continue + qv = _safe_float(row.get("volume_24h_quote") or row.get("volume_24h_settle")) + if qv is None or qv <= 0: + bv = _safe_float(row.get("volume_24h_base")) + lp = _safe_float(row.get("last") or row.get("mark_price")) + if bv and lp: + qv = bv * lp + if qv is None or qv <= 0: + continue + sym = _hub_symbol_from_base(base) + prev = by_base.get(base) + if prev is None or qv > prev[1]: + by_base[base] = (sym, float(qv)) + if by_base: + return _merge_scores(by_base) + except Exception: + continue + return [] + + +def _scores_from_markets( + exchange, + tickers: dict, + exchange_id: str, +) -> list[tuple[str, str, float]]: + by_base: dict[str, tuple[str, float]] = {} + markets = getattr(exchange, "markets", None) or {} + for sym, mk in markets.items(): + try: + if not _is_usdt_linear_swap(mk, sym): + continue + ticker = _lookup_ticker(tickers, sym, mk) + qv = _quote_volume_from_ticker(ticker, mk, exchange_id=exchange_id) + if qv is None or qv <= 0: + continue + hub_sym = _hub_symbol_from_market(mk, sym) + base = _ticker_base(hub_sym) + if not base: + continue + prev = by_base.get(base) + if prev is None or qv > prev[1]: + by_base[base] = (hub_sym, float(qv)) + except Exception: + continue + return _merge_scores(by_base) + + +def _collect_scores(exchange, exchange_id: str) -> list[tuple[str, str, float]]: + ex_id = str(exchange_id or "").lower() + if ex_id == "okx": + return _scores_from_okx(exchange) + if ex_id == "binance": + return _scores_from_binance(exchange) + if ex_id in ("gateio", "gate", "gate_bot"): + return _scores_from_gate(exchange) + tickers = exchange.fetch_tickers() + return _scores_from_markets(exchange, tickers or {}, ex_id) + + +def _uses_lightweight_volume_scores(exchange_id: str) -> bool: + ex_id = str(exchange_id or "").lower() + return ex_id in ("okx", "binance", "gateio", "gate", "gate_bot") + + +def build_usdt_swap_volume_ranks( + exchange, + ensure_markets_loaded: Callable[[], None], + *, + exchange_id: str | None = None, +) -> tuple[dict[str, int], int]: + """ + 全市场 USDT 永续 24h 成交额排名(base -> rank)。 + 优先各所轻量 ticker API,避免 fetch_tickers() 拉全市场(Gate/Binance 内存优化)。 + """ + ex_id = str(exchange_id or getattr(exchange, "id", "") or "").lower() + if not _uses_lightweight_volume_scores(ex_id): + ensure_markets_loaded() + scored = _collect_scores(exchange, ex_id) + ranks: dict[str, int] = {} + for idx, (_sym, base, _qv) in enumerate(scored, 1): + if base and base not in ranks: + ranks[base] = idx + return ranks, len(scored) + + +def resolve_daily_volume_rank( + target_base: str, + cache: dict[str, Any], + *, + now_ts: float, + ttl_sec: float, + exchange, + ensure_markets_loaded: Callable[[], None], + exchange_id: str | None = None, + cache_version: int = LIQUIDITY_RANK_CACHE_VERSION, +) -> tuple[int | None, int]: + """关键位门控:按 base 查 24h 成交额全市场排名;cache 带 TTL。""" + cached_ok = ( + cache.get("version") == cache_version + and cache.get("updated_at") + and now_ts - float(cache["updated_at"]) < ttl_sec + ) + if not cached_ok: + try: + ranks, total = build_usdt_swap_volume_ranks( + exchange, + ensure_markets_loaded, + exchange_id=exchange_id, + ) + if total > 0 and ranks: + cache["ranks"] = ranks + cache["total"] = total + cache["version"] = cache_version + cache["updated_at"] = now_ts + except Exception: + pass + ranks = cache.get("ranks") or {} + total = int(cache.get("total") or 0) + base = str(target_base or "").strip().upper() + return ranks.get(base), total + + +def fetch_usdt_swap_volume_rank( + exchange, + ensure_markets_loaded: Callable[[], None], + *, + top_n: int = TOP_N_DEFAULT, + rank_date: str | None = None, + exchange_id: str | None = None, +) -> dict[str, Any]: + """从 ccxt 拉全市场 USDT 永续 ticker,按 24h 成交额(USDT) 取 Top N。""" + top_n = max(1, min(int(top_n or TOP_N_DEFAULT), 100)) + ensure_markets_loaded() + ex_id = str(exchange_id or getattr(exchange, "id", "") or "").lower() + + try: + scored = _collect_scores(exchange, ex_id) + except Exception as e: + return {"ok": False, "msg": str(e)} + + items = [] + for idx, (hub_sym, base, qv) in enumerate(scored[:top_n], 1): + items.append( + { + "rank": idx, + "symbol": hub_sym, + "base": base, + "volume_quote": round(qv, 4), + } + ) + return { + "ok": True, + "rank_date": rank_date or rank_date_label(), + "items": items, + "total_symbols": len(scored), + "exchange_id": ex_id, + "fetched_at": datetime.now(volume_rank_timezone()).isoformat(timespec="seconds"), + } + + +def format_volume_quote(value: float | None) -> str: + n = _safe_float(value) + if n is None or n <= 0: + return "—" + if n >= 1e9: + return f"{n / 1e9:.2f}B" + if n >= 1e6: + return f"{n / 1e6:.2f}M" + if n >= 1e3: + return f"{n / 1e3:.2f}K" + return f"{n:.0f}" + + +def load_volume_rank_cache(path: Path | None = None) -> dict[str, Any]: + p = path or default_cache_path() + if not p.is_file(): + return {"version": CACHE_VERSION, "exchanges": {}} + try: + data = json.loads(p.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return {"version": CACHE_VERSION, "exchanges": {}} + if int(data.get("version") or 0) < CACHE_VERSION: + return {"version": CACHE_VERSION, "exchanges": {}} + data.setdefault("version", CACHE_VERSION) + data.setdefault("exchanges", {}) + return data + except Exception: + return {"version": CACHE_VERSION, "exchanges": {}} + + +def save_volume_rank_cache(data: dict[str, Any], path: Path | None = None) -> None: + p = path or default_cache_path() + p.parent.mkdir(parents=True, exist_ok=True) + payload = dict(data) + payload["version"] = CACHE_VERSION + payload["updated_at"] = datetime.now(volume_rank_timezone()).isoformat(timespec="seconds") + p.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + + +def merge_exchange_rank( + cache: dict[str, Any], + exchange_key: str, + payload: dict[str, Any], +) -> dict[str, Any]: + ex_k = str(exchange_key or "").strip().lower() + if not ex_k or not payload.get("ok"): + return cache + exchanges = dict(cache.get("exchanges") or {}) + exchanges[ex_k] = { + "rank_date": payload.get("rank_date"), + "items": payload.get("items") or [], + "total_symbols": int(payload.get("total_symbols") or 0), + "fetched_at": payload.get("fetched_at"), + "error": None, + } + out = dict(cache) + out["exchanges"] = exchanges + out["rank_date"] = payload.get("rank_date") or cache.get("rank_date") + return out + + +def _exchange_rank_row_stale(row: dict[str, Any] | None) -> bool: + if not row: + return True + items = row.get("items") or [] + if len(items) < TOP_N_DEFAULT: + return True + total = int(row.get("total_symbols") or 0) + if total > 0 and total < TOP_N_DEFAULT: + return True + return False + + +def cache_needs_refresh( + cache: dict[str, Any], + *, + expected_rank_date: str | None = None, + required_keys: list[str] | None = None, +) -> bool: + expected = expected_rank_date or rank_date_label() + if int(cache.get("version") or 0) < CACHE_VERSION: + return True + exchanges = cache.get("exchanges") or {} + if not exchanges: + return True + if str(cache.get("rank_date") or "") != expected: + return True + keys = required_keys or list(exchanges.keys()) + if not keys: + return True + for key in keys: + ex_k = str(key or "").strip().lower() + if not ex_k: + continue + if _exchange_rank_row_stale(exchanges.get(ex_k)): + return True + return False + + +def get_cached_rank( + cache: dict[str, Any], + exchange_key: str, + *, + top_n: int = TOP_N_DEFAULT, +) -> dict[str, Any]: + ex_k = str(exchange_key or "").strip().lower() + ex_data = (cache.get("exchanges") or {}).get(ex_k) or {} + items = list(ex_data.get("items") or [])[: max(1, int(top_n))] + stale = _exchange_rank_row_stale(ex_data) + return { + "ok": True, + "exchange_key": ex_k, + "rank_date": ex_data.get("rank_date") or cache.get("rank_date"), + "updated_at": cache.get("updated_at"), + "items": items, + "item_count": len(items), + "expected_count": int(top_n), + "total_symbols": int(ex_data.get("total_symbols") or 0), + "stale": stale, + "error": ex_data.get("error"), + } diff --git a/lib/instance/__init__.py b/lib/instance/__init__.py new file mode 100644 index 0000000..ab164b5 --- /dev/null +++ b/lib/instance/__init__.py @@ -0,0 +1 @@ +"""Shared library package.""" diff --git a/focus_chart_lib.py b/lib/instance/focus_chart_lib.py similarity index 95% rename from focus_chart_lib.py rename to lib/instance/focus_chart_lib.py index 4da1cf5..e62b847 100644 --- a/focus_chart_lib.py +++ b/lib/instance/focus_chart_lib.py @@ -1,187 +1,187 @@ -"""实盘/关键位放大 K 线:订单元数据与交易所浮盈、价格展示精度。""" -from __future__ import annotations - -from typing import Any, Callable, Optional - -from hub_ohlcv_lib import ( - normalize_price_tick, - price_tick_from_market, - round_ohlcv_bars_to_tick, -) -from order_monitor_display_lib import ( - apply_order_live_price_display, - apply_order_price_display_fields, -) - - -def resolve_kline_price_tick( - exchange: Any, - exchange_symbol: str, - *, - ensure_markets_fn: Callable[[], None], -) -> Optional[float]: - """交易所最小价格变动单位,供 lightweight-charts 右侧刻度与标记线对齐。""" - if not exchange_symbol: - return None - try: - ensure_markets_fn() - return normalize_price_tick(price_tick_from_market(exchange, exchange_symbol)) - except Exception: - return None - - -def align_candles_to_price_tick( - candles: list[dict[str, Any]], - price_tick: Optional[float], -) -> None: - if price_tick is not None and candles: - round_ohlcv_bars_to_tick(candles, price_tick) - - -def kline_api_price_fields( - exchange: Any, - exchange_symbol: str, - candles: list[dict[str, Any]], - *, - ensure_markets_fn: Callable[[], None], -) -> dict[str, Any]: - tick = resolve_kline_price_tick( - exchange, exchange_symbol, ensure_markets_fn=ensure_markets_fn - ) - align_candles_to_price_tick(candles, tick) - return {"price_tick": tick} - - -def load_swap_positions_for_order_kline( - exchange: Any, - *, - private_configured: bool, - ensure_markets_fn: Callable[[], None], - settle: str = "usdt", -) -> list: - if not private_configured: - return [] - try: - ensure_markets_fn() - try: - return exchange.fetch_positions(None, {"settle": settle}) or [] - except Exception: - return exchange.fetch_positions() or [] - except Exception: - return [] - - -def metrics_for_order_item( - order_item: dict[str, Any], - positions: list, - *, - resolve_ex_sym_fn: Callable[[Any], str], - select_live_fn: Callable[[list, str, str], Any], - parse_metrics_fn: Callable[..., Optional[dict]], -) -> Optional[dict]: - if not positions: - return None - ex_sym = resolve_ex_sym_fn(order_item) - direction = order_item.get("direction") or "long" - prow = select_live_fn(positions, ex_sym, direction) - if not prow: - return None - lev = order_item.get("leverage") - return parse_metrics_fn(prow, order_leverage=lev) - - -def build_order_kline_order_payload( - order_item: dict[str, Any], - *, - ticker_price: Any, - format_price_fn: Callable[[Any, Any], str], - calc_pnl_fn: Callable[..., float], - calc_rr_ratio_fn: Callable[..., Optional[float]], - ex_metrics: Optional[dict] = None, -) -> dict[str, Any]: - sym = order_item.get("symbol") or "" - direction = order_item.get("direction") or "long" - margin = float(order_item.get("margin_capital") or 0) - leverage = float(order_item.get("leverage") or 0) - entry = float(order_item.get("trigger_price") or 0) - - float_pnl = 0.0 - float_pct = 0.0 - if ticker_price and entry > 0: - float_pnl = float( - calc_pnl_fn(direction, entry, ticker_price, margin, leverage) - ) - float_pct = round((float_pnl / margin * 100), 4) if margin > 0 else 0.0 - - px_for_fmt = ticker_price - mark_raw = None - if ex_metrics and ex_metrics.get("mark_price") is not None: - mark_raw = ex_metrics["mark_price"] - try: - px_for_fmt = float(mark_raw) - except (TypeError, ValueError): - pass - - if ex_metrics and ex_metrics.get("unrealized_pnl") is not None: - float_pnl = round(float(ex_metrics["unrealized_pnl"]), 2) - denom = ex_metrics.get("initial_margin") or margin - float_pct = ( - round((float_pnl / float(denom)) * 100, 4) - if denom and float(denom) > 0 - else float_pct - ) - - payload: dict[str, Any] = { - "id": order_item["id"], - "symbol": sym, - "direction": direction, - "trigger_price": order_item.get("trigger_price"), - "stop_loss": order_item.get("stop_loss"), - "take_profit": order_item.get("take_profit"), - "trigger_price_display": format_price_fn(sym, order_item.get("trigger_price")), - "stop_loss_display": format_price_fn(sym, order_item.get("stop_loss")), - "take_profit_display": format_price_fn(sym, order_item.get("take_profit")), - "margin_capital": order_item.get("margin_capital"), - "leverage": order_item.get("leverage"), - "position_ratio": order_item.get("position_ratio"), - "breakeven_enabled": bool(int(order_item.get("breakeven_enabled") or 0)), - "current_price": round(float(px_for_fmt), 8) if px_for_fmt is not None else None, - "float_pnl": round(float(float_pnl), 2), - "float_pct": float_pct, - } - apply_order_price_display_fields( - payload, - direction=direction, - entry_price=order_item.get("trigger_price"), - initial_stop_loss=order_item.get("initial_stop_loss"), - stop_loss=order_item.get("stop_loss"), - take_profit=order_item.get("take_profit"), - calc_rr_ratio_fn=calc_rr_ratio_fn, - ) - apply_order_live_price_display( - payload, - sym, - ticker_price, - mark_raw, - format_price_fn, - ) - payload["current_price_display"] = payload.get("price_display") or ( - format_price_fn(sym, px_for_fmt) if px_for_fmt is not None else None - ) - return payload - - -def enrich_key_kline_response( - *, - symbol: str, - current_price: Any, - key_info: Optional[dict[str, Any]], - format_price_fn: Callable[[Any, Any], str], -) -> tuple[Any, Optional[dict[str, Any]]]: - price_display = format_price_fn(symbol, current_price) if current_price is not None else None - if key_info is None: - return price_display, None - enriched = dict(key_info) - enriched["upper_display"] = format_price_fn(symbol, key_info.get("upper")) - enriched["lower_display"] = format_price_fn(symbol, key_info.get("lower")) - return price_display, enriched +"""实盘/关键位放大 K 线:订单元数据与交易所浮盈、价格展示精度。""" +from __future__ import annotations + +from typing import Any, Callable, Optional + +from lib.hub.hub_ohlcv_lib import ( + normalize_price_tick, + price_tick_from_market, + round_ohlcv_bars_to_tick, +) +from lib.trade.order_monitor_display_lib import ( + apply_order_live_price_display, + apply_order_price_display_fields, +) + + +def resolve_kline_price_tick( + exchange: Any, + exchange_symbol: str, + *, + ensure_markets_fn: Callable[[], None], +) -> Optional[float]: + """交易所最小价格变动单位,供 lightweight-charts 右侧刻度与标记线对齐。""" + if not exchange_symbol: + return None + try: + ensure_markets_fn() + return normalize_price_tick(price_tick_from_market(exchange, exchange_symbol)) + except Exception: + return None + + +def align_candles_to_price_tick( + candles: list[dict[str, Any]], + price_tick: Optional[float], +) -> None: + if price_tick is not None and candles: + round_ohlcv_bars_to_tick(candles, price_tick) + + +def kline_api_price_fields( + exchange: Any, + exchange_symbol: str, + candles: list[dict[str, Any]], + *, + ensure_markets_fn: Callable[[], None], +) -> dict[str, Any]: + tick = resolve_kline_price_tick( + exchange, exchange_symbol, ensure_markets_fn=ensure_markets_fn + ) + align_candles_to_price_tick(candles, tick) + return {"price_tick": tick} + + +def load_swap_positions_for_order_kline( + exchange: Any, + *, + private_configured: bool, + ensure_markets_fn: Callable[[], None], + settle: str = "usdt", +) -> list: + if not private_configured: + return [] + try: + ensure_markets_fn() + try: + return exchange.fetch_positions(None, {"settle": settle}) or [] + except Exception: + return exchange.fetch_positions() or [] + except Exception: + return [] + + +def metrics_for_order_item( + order_item: dict[str, Any], + positions: list, + *, + resolve_ex_sym_fn: Callable[[Any], str], + select_live_fn: Callable[[list, str, str], Any], + parse_metrics_fn: Callable[..., Optional[dict]], +) -> Optional[dict]: + if not positions: + return None + ex_sym = resolve_ex_sym_fn(order_item) + direction = order_item.get("direction") or "long" + prow = select_live_fn(positions, ex_sym, direction) + if not prow: + return None + lev = order_item.get("leverage") + return parse_metrics_fn(prow, order_leverage=lev) + + +def build_order_kline_order_payload( + order_item: dict[str, Any], + *, + ticker_price: Any, + format_price_fn: Callable[[Any, Any], str], + calc_pnl_fn: Callable[..., float], + calc_rr_ratio_fn: Callable[..., Optional[float]], + ex_metrics: Optional[dict] = None, +) -> dict[str, Any]: + sym = order_item.get("symbol") or "" + direction = order_item.get("direction") or "long" + margin = float(order_item.get("margin_capital") or 0) + leverage = float(order_item.get("leverage") or 0) + entry = float(order_item.get("trigger_price") or 0) + + float_pnl = 0.0 + float_pct = 0.0 + if ticker_price and entry > 0: + float_pnl = float( + calc_pnl_fn(direction, entry, ticker_price, margin, leverage) + ) + float_pct = round((float_pnl / margin * 100), 4) if margin > 0 else 0.0 + + px_for_fmt = ticker_price + mark_raw = None + if ex_metrics and ex_metrics.get("mark_price") is not None: + mark_raw = ex_metrics["mark_price"] + try: + px_for_fmt = float(mark_raw) + except (TypeError, ValueError): + pass + + if ex_metrics and ex_metrics.get("unrealized_pnl") is not None: + float_pnl = round(float(ex_metrics["unrealized_pnl"]), 2) + denom = ex_metrics.get("initial_margin") or margin + float_pct = ( + round((float_pnl / float(denom)) * 100, 4) + if denom and float(denom) > 0 + else float_pct + ) + + payload: dict[str, Any] = { + "id": order_item["id"], + "symbol": sym, + "direction": direction, + "trigger_price": order_item.get("trigger_price"), + "stop_loss": order_item.get("stop_loss"), + "take_profit": order_item.get("take_profit"), + "trigger_price_display": format_price_fn(sym, order_item.get("trigger_price")), + "stop_loss_display": format_price_fn(sym, order_item.get("stop_loss")), + "take_profit_display": format_price_fn(sym, order_item.get("take_profit")), + "margin_capital": order_item.get("margin_capital"), + "leverage": order_item.get("leverage"), + "position_ratio": order_item.get("position_ratio"), + "breakeven_enabled": bool(int(order_item.get("breakeven_enabled") or 0)), + "current_price": round(float(px_for_fmt), 8) if px_for_fmt is not None else None, + "float_pnl": round(float(float_pnl), 2), + "float_pct": float_pct, + } + apply_order_price_display_fields( + payload, + direction=direction, + entry_price=order_item.get("trigger_price"), + initial_stop_loss=order_item.get("initial_stop_loss"), + stop_loss=order_item.get("stop_loss"), + take_profit=order_item.get("take_profit"), + calc_rr_ratio_fn=calc_rr_ratio_fn, + ) + apply_order_live_price_display( + payload, + sym, + ticker_price, + mark_raw, + format_price_fn, + ) + payload["current_price_display"] = payload.get("price_display") or ( + format_price_fn(sym, px_for_fmt) if px_for_fmt is not None else None + ) + return payload + + +def enrich_key_kline_response( + *, + symbol: str, + current_price: Any, + key_info: Optional[dict[str, Any]], + format_price_fn: Callable[[Any, Any], str], +) -> tuple[Any, Optional[dict[str, Any]]]: + price_display = format_price_fn(symbol, current_price) if current_price is not None else None + if key_info is None: + return price_display, None + enriched = dict(key_info) + enriched["upper_display"] = format_price_fn(symbol, key_info.get("upper")) + enriched["lower_display"] = format_price_fn(symbol, key_info.get("lower")) + return price_display, enriched diff --git a/instance_embed_context_lib.py b/lib/instance/instance_embed_context_lib.py similarity index 94% rename from instance_embed_context_lib.py rename to lib/instance/instance_embed_context_lib.py index dc682ca..188b5a3 100644 --- a/instance_embed_context_lib.py +++ b/lib/instance/instance_embed_context_lib.py @@ -1,84 +1,84 @@ -"""embed 壳/片段:按 tab 裁剪 render_main_page 的数据加载,降内存与 API 压力。""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any - -EMBED_STRATEGY_PAGES = frozenset({"strategy", "strategy_trend", "strategy_roll", "strategy_records"}) - - -@dataclass(frozen=True) -class EmbedRenderPlan: - exchange_capitals: bool - records_rows: bool - records_summary: bool - key_history: bool - key_list: bool - orders: bool - stats_bundle: bool - strategy: bool - orphan_live: bool - - -def embed_render_plan(page: str, embed_mode: str | None) -> EmbedRenderPlan: - if embed_mode not in ("fragment", "shell"): - return EmbedRenderPlan( - exchange_capitals=True, - records_rows=True, - records_summary=False, - key_history=True, - key_list=True, - orders=True, - stats_bundle=True, - strategy=True, - orphan_live=True, - ) - is_shell = embed_mode == "shell" - is_strategy = page in EMBED_STRATEGY_PAGES - return EmbedRenderPlan( - exchange_capitals=is_shell, - records_rows=page == "records", - records_summary=is_shell and page != "records", - key_history=page == "key_monitor", - key_list=page in ("key_monitor", "trade") or is_strategy, - orders=page == "trade" or is_strategy, - stats_bundle=page == "stats", - strategy=is_strategy, - orphan_live=page == "trade" and is_shell, - ) - - -def trade_records_summary(conn, start_bj: str, end_bj: str, tr_ts: str) -> dict[str, Any]: - """顶栏统计用 COUNT,避免 embed 壳拉 1000 行交易记录。""" - from trade_result_lib import sql_effective_pnl_expr - - pnl_sql = sql_effective_pnl_expr() - row = conn.execute( - f""" - SELECT - COUNT(*) AS total, - SUM(CASE WHEN result = '错过' THEN 1 ELSE 0 END) AS miss_count, - SUM(CASE WHEN {pnl_sql} > 0 THEN 1 ELSE 0 END) AS wins, - SUM(CASE WHEN result = '错过' AND COALESCE(miss_reason,'') LIKE '%持仓占用%' THEN 1 ELSE 0 END) AS occupied_miss - FROM trade_records - WHERE {tr_ts} >= ? AND {tr_ts} <= ? - """, - (start_bj, end_bj), - ).fetchone() - total = int(row["total"] or 0) if row else 0 - miss_count = int(row["miss_count"] or 0) if row else 0 - wins = int(row["wins"] or 0) if row else 0 - occupied_miss_total = int(row["occupied_miss"] or 0) if row else 0 - rate = round(wins / total * 100, 2) if total else 0 - return { - "records": [], - "total": total, - "miss_count": miss_count, - "rate": rate, - "occupied_miss_total": occupied_miss_total, - } - - -def minimal_stats_bundle(reset_hour: int) -> dict[str, Any]: - return {"stats_reset_hour": reset_hour, "segments": []} +"""embed 壳/片段:按 tab 裁剪 render_main_page 的数据加载,降内存与 API 压力。""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +EMBED_STRATEGY_PAGES = frozenset({"strategy", "strategy_trend", "strategy_roll", "strategy_records"}) + + +@dataclass(frozen=True) +class EmbedRenderPlan: + exchange_capitals: bool + records_rows: bool + records_summary: bool + key_history: bool + key_list: bool + orders: bool + stats_bundle: bool + strategy: bool + orphan_live: bool + + +def embed_render_plan(page: str, embed_mode: str | None) -> EmbedRenderPlan: + if embed_mode not in ("fragment", "shell"): + return EmbedRenderPlan( + exchange_capitals=True, + records_rows=True, + records_summary=False, + key_history=True, + key_list=True, + orders=True, + stats_bundle=True, + strategy=True, + orphan_live=True, + ) + is_shell = embed_mode == "shell" + is_strategy = page in EMBED_STRATEGY_PAGES + return EmbedRenderPlan( + exchange_capitals=is_shell, + records_rows=page == "records", + records_summary=is_shell and page != "records", + key_history=page == "key_monitor", + key_list=page in ("key_monitor", "trade") or is_strategy, + orders=page == "trade" or is_strategy, + stats_bundle=page == "stats", + strategy=is_strategy, + orphan_live=page == "trade" and is_shell, + ) + + +def trade_records_summary(conn, start_bj: str, end_bj: str, tr_ts: str) -> dict[str, Any]: + """顶栏统计用 COUNT,避免 embed 壳拉 1000 行交易记录。""" + from lib.trade.trade_result_lib import sql_effective_pnl_expr + + pnl_sql = sql_effective_pnl_expr() + row = conn.execute( + f""" + SELECT + COUNT(*) AS total, + SUM(CASE WHEN result = '错过' THEN 1 ELSE 0 END) AS miss_count, + SUM(CASE WHEN {pnl_sql} > 0 THEN 1 ELSE 0 END) AS wins, + SUM(CASE WHEN result = '错过' AND COALESCE(miss_reason,'') LIKE '%持仓占用%' THEN 1 ELSE 0 END) AS occupied_miss + FROM trade_records + WHERE {tr_ts} >= ? AND {tr_ts} <= ? + """, + (start_bj, end_bj), + ).fetchone() + total = int(row["total"] or 0) if row else 0 + miss_count = int(row["miss_count"] or 0) if row else 0 + wins = int(row["wins"] or 0) if row else 0 + occupied_miss_total = int(row["occupied_miss"] or 0) if row else 0 + rate = round(wins / total * 100, 2) if total else 0 + return { + "records": [], + "total": total, + "miss_count": miss_count, + "rate": rate, + "occupied_miss_total": occupied_miss_total, + } + + +def minimal_stats_bundle(reset_hour: int) -> dict[str, Any]: + return {"stats_reset_hour": reset_hour, "segments": []} diff --git a/instance_embed_lib.py b/lib/instance/instance_embed_lib.py similarity index 95% rename from instance_embed_lib.py rename to lib/instance/instance_embed_lib.py index f8ebf72..3afc085 100644 --- a/instance_embed_lib.py +++ b/lib/instance/instance_embed_lib.py @@ -1,147 +1,148 @@ -"""中控 iframe:壳常驻 + tab 内容 API(/embed、/api/embed/page/)。""" - -from __future__ import annotations - -import os -from typing import Callable -from urllib.parse import parse_qsl, urlencode, urlsplit - -from flask import Flask, Response, jsonify, redirect, request, session -from jinja2 import ChoiceLoader, FileSystemLoader - -EMBED_TABS: tuple[str, ...] = ( - "key_monitor", - "trade", - "strategy", - "strategy_records", - "records", - "stats", -) - -PATH_TO_EMBED_TAB: dict[str, str] = { - "/": "trade", - "/trade": "trade", - "/key_monitor": "key_monitor", - "/strategy": "strategy", - "/strategy/trend": "strategy", - "/strategy/roll": "strategy", - "/strategy/records": "strategy_records", - "/records": "records", - "/stats": "stats", -} - -ORDER_RULE_TIPS_BY_EXCHANGE: dict[str, str] = { - "gate": "order_monitor_rule_tips_gate.html", - "gate_bot": "order_monitor_rule_tips_gate.html", - "binance": "order_monitor_rule_tips_binance.html", - "okx": "order_monitor_rule_tips_okx.html", -} - - -def order_rule_tips_template(exchange_key: str) -> str: - ex = (exchange_key or "").strip().lower() - return ORDER_RULE_TIPS_BY_EXCHANGE.get(ex, "order_monitor_rule_tips_gate.html") - - -def include_transfer_block(exchange_key: str) -> bool: - return (exchange_key or "").strip().lower() in ("gate", "gate_bot") - - -def path_to_embed_tab(path: str) -> str | None: - p = (path or "/").strip() - if not p.startswith("/"): - p = "/" + p - base = urlsplit(p).path.rstrip("/") or "/" - return PATH_TO_EMBED_TAB.get(base) - - -def embed_shell_enabled() -> bool: - return (os.getenv("HUB_EMBED_SHELL") or "1").strip().lower() in ("1", "true", "yes", "on") - - -def rewrite_embed_dest(path: str, hub_theme: str | None = None) -> str: - """embed=1 打开时:/trade → /embed?tab=trade&embed=1""" - if not embed_shell_enabled(): - split = urlsplit(path or "/") - q = dict(parse_qsl(split.query, keep_blank_values=True)) - q["embed"] = "1" - ht = (hub_theme or q.get("hub_theme") or "").strip().lower() - if ht in ("light", "dark"): - q["hub_theme"] = ht - dest = split.path or "/" - if q: - return f"{dest}?{urlencode(q)}" - return dest + "?embed=1" - split = urlsplit(path or "/") - tab = path_to_embed_tab(split.path) - q = dict(parse_qsl(split.query, keep_blank_values=True)) - if tab: - q["tab"] = tab - q["embed"] = "1" - ht = (hub_theme or q.get("hub_theme") or "").strip().lower() - if ht in ("light", "dark"): - q["hub_theme"] = ht - return f"/embed?{urlencode(q)}" - q["embed"] = "1" - ht = (hub_theme or q.get("hub_theme") or "").strip().lower() - if ht in ("light", "dark"): - q["hub_theme"] = ht - dest = split.path or "/" - if split.query: - dest += "?" + split.query - if "embed=1" not in dest: - sep = "&" if "?" in dest else "?" - dest += f"{sep}embed=1" - if ht in ("light", "dark") and "hub_theme=" not in dest: - sep = "&" if "?" in dest else "?" - dest += f"{sep}hub_theme={ht}" - return dest - - -def attach_embed_templates(app: Flask, repo_root: str) -> None: - embed_dir = os.path.join(repo_root, "embed_templates") - if not os.path.isdir(embed_dir): - return - existing = app.jinja_loader - loaders = [FileSystemLoader(embed_dir)] - if existing is not None: - if isinstance(existing, ChoiceLoader): - loaders = list(existing.loaders) + loaders - else: - loaders.insert(0, existing) - app.jinja_loader = ChoiceLoader(loaders) - - -def register_embed_routes( - app: Flask, - login_required: Callable, - render_main_page_fn: Callable, -) -> None: - app.config["RENDER_MAIN_PAGE_FN"] = render_main_page_fn - - @login_required - @app.route("/embed") - def embed_shell_page(): - tab = (request.args.get("tab") or "trade").strip() - if tab not in EMBED_TABS: - tab = "trade" - session["hub_embed_shell"] = True - return render_main_page_fn(tab, embed_mode="shell") - - @login_required - @app.route("/api/embed/page/") - def api_embed_page(tab: str): - tab = (tab or "").strip() - if tab not in EMBED_TABS: - return jsonify({"ok": False, "msg": "unknown tab"}), 404 - html = render_main_page_fn(tab, embed_mode="fragment") - if isinstance(html, Response): - html = html.get_data(as_text=True) - return jsonify({"ok": True, "page": tab, "html": html}) - - -def embed_context_extras(exchange_key: str) -> dict: - return { - "order_rule_tips_tpl": order_rule_tips_template(exchange_key), - "include_transfer_block": include_transfer_block(exchange_key), - } +"""中控 iframe:壳常驻 + tab 内容 API(/embed、/api/embed/page/)。""" +from __future__ import annotations + +from lib.paths import embed_templates_dir + +import os +from typing import Callable +from urllib.parse import parse_qsl, urlencode, urlsplit + +from flask import Flask, Response, jsonify, redirect, request, session +from jinja2 import ChoiceLoader, FileSystemLoader + +EMBED_TABS: tuple[str, ...] = ( + "key_monitor", + "trade", + "strategy", + "strategy_records", + "records", + "stats", +) + +PATH_TO_EMBED_TAB: dict[str, str] = { + "/": "trade", + "/trade": "trade", + "/key_monitor": "key_monitor", + "/strategy": "strategy", + "/strategy/trend": "strategy", + "/strategy/roll": "strategy", + "/strategy/records": "strategy_records", + "/records": "records", + "/stats": "stats", +} + +ORDER_RULE_TIPS_BY_EXCHANGE: dict[str, str] = { + "gate": "order_monitor_rule_tips_gate.html", + "gate_bot": "order_monitor_rule_tips_gate.html", + "binance": "order_monitor_rule_tips_binance.html", + "okx": "order_monitor_rule_tips_okx.html", +} + + +def order_rule_tips_template(exchange_key: str) -> str: + ex = (exchange_key or "").strip().lower() + return ORDER_RULE_TIPS_BY_EXCHANGE.get(ex, "order_monitor_rule_tips_gate.html") + + +def include_transfer_block(exchange_key: str) -> bool: + return (exchange_key or "").strip().lower() in ("gate", "gate_bot") + + +def path_to_embed_tab(path: str) -> str | None: + p = (path or "/").strip() + if not p.startswith("/"): + p = "/" + p + base = urlsplit(p).path.rstrip("/") or "/" + return PATH_TO_EMBED_TAB.get(base) + + +def embed_shell_enabled() -> bool: + return (os.getenv("HUB_EMBED_SHELL") or "1").strip().lower() in ("1", "true", "yes", "on") + + +def rewrite_embed_dest(path: str, hub_theme: str | None = None) -> str: + """embed=1 打开时:/trade → /embed?tab=trade&embed=1""" + if not embed_shell_enabled(): + split = urlsplit(path or "/") + q = dict(parse_qsl(split.query, keep_blank_values=True)) + q["embed"] = "1" + ht = (hub_theme or q.get("hub_theme") or "").strip().lower() + if ht in ("light", "dark"): + q["hub_theme"] = ht + dest = split.path or "/" + if q: + return f"{dest}?{urlencode(q)}" + return dest + "?embed=1" + split = urlsplit(path or "/") + tab = path_to_embed_tab(split.path) + q = dict(parse_qsl(split.query, keep_blank_values=True)) + if tab: + q["tab"] = tab + q["embed"] = "1" + ht = (hub_theme or q.get("hub_theme") or "").strip().lower() + if ht in ("light", "dark"): + q["hub_theme"] = ht + return f"/embed?{urlencode(q)}" + q["embed"] = "1" + ht = (hub_theme or q.get("hub_theme") or "").strip().lower() + if ht in ("light", "dark"): + q["hub_theme"] = ht + dest = split.path or "/" + if split.query: + dest += "?" + split.query + if "embed=1" not in dest: + sep = "&" if "?" in dest else "?" + dest += f"{sep}embed=1" + if ht in ("light", "dark") and "hub_theme=" not in dest: + sep = "&" if "?" in dest else "?" + dest += f"{sep}hub_theme={ht}" + return dest + + +def attach_embed_templates(app: Flask, repo_root: str) -> None: + embed_dir = embed_templates_dir(repo_root) + if not os.path.isdir(embed_dir): + return + existing = app.jinja_loader + loaders = [FileSystemLoader(embed_dir)] + if existing is not None: + if isinstance(existing, ChoiceLoader): + loaders = list(existing.loaders) + loaders + else: + loaders.insert(0, existing) + app.jinja_loader = ChoiceLoader(loaders) + + +def register_embed_routes( + app: Flask, + login_required: Callable, + render_main_page_fn: Callable, +) -> None: + app.config["RENDER_MAIN_PAGE_FN"] = render_main_page_fn + + @login_required + @app.route("/embed") + def embed_shell_page(): + tab = (request.args.get("tab") or "trade").strip() + if tab not in EMBED_TABS: + tab = "trade" + session["hub_embed_shell"] = True + return render_main_page_fn(tab, embed_mode="shell") + + @login_required + @app.route("/api/embed/page/") + def api_embed_page(tab: str): + tab = (tab or "").strip() + if tab not in EMBED_TABS: + return jsonify({"ok": False, "msg": "unknown tab"}), 404 + html = render_main_page_fn(tab, embed_mode="fragment") + if isinstance(html, Response): + html = html.get_data(as_text=True) + return jsonify({"ok": True, "page": tab, "html": html}) + + +def embed_context_extras(exchange_key: str) -> dict: + return { + "order_rule_tips_tpl": order_rule_tips_template(exchange_key), + "include_transfer_block": include_transfer_block(exchange_key), + } diff --git a/instance_nav_lib.py b/lib/instance/instance_nav_lib.py similarity index 100% rename from instance_nav_lib.py rename to lib/instance/instance_nav_lib.py diff --git a/journal_chart_lib.py b/lib/instance/journal_chart_lib.py similarity index 100% rename from journal_chart_lib.py rename to lib/instance/journal_chart_lib.py diff --git a/embed_templates/embed_boot_scripts.html b/lib/instance/templates/embed_boot_scripts.html similarity index 100% rename from embed_templates/embed_boot_scripts.html rename to lib/instance/templates/embed_boot_scripts.html diff --git a/embed_templates/embed_page_fragment.html b/lib/instance/templates/embed_page_fragment.html similarity index 100% rename from embed_templates/embed_page_fragment.html rename to lib/instance/templates/embed_page_fragment.html diff --git a/embed_templates/embed_shell.html b/lib/instance/templates/embed_shell.html similarity index 100% rename from embed_templates/embed_shell.html rename to lib/instance/templates/embed_shell.html diff --git a/lib/key_monitor/__init__.py b/lib/key_monitor/__init__.py new file mode 100644 index 0000000..ab164b5 --- /dev/null +++ b/lib/key_monitor/__init__.py @@ -0,0 +1 @@ +"""Shared library package.""" diff --git a/false_breakout_key_monitor_lib.py b/lib/key_monitor/false_breakout_key_monitor_lib.py similarity index 95% rename from false_breakout_key_monitor_lib.py rename to lib/key_monitor/false_breakout_key_monitor_lib.py index f8393a8..3a51104 100644 --- a/false_breakout_key_monitor_lib.py +++ b/lib/key_monitor/false_breakout_key_monitor_lib.py @@ -1,145 +1,145 @@ -"""假突破关键位监控:BTC/ETH 限价挂单(共享计算与校验)。""" -from __future__ import annotations - -from datetime import datetime, timedelta -from typing import Any, Optional - -FALSE_BREAKOUT_MONITOR_TYPE = "假突破" -FALSE_BREAKOUT_SYMBOLS = frozenset({"BTC/USDT", "ETH/USDT"}) -FALSE_BREAKOUT_OFFSET_PCT = 0.1 -FALSE_BREAKOUT_SL_PCT = 0.5 -FALSE_BREAKOUT_RR = 1.5 -FALSE_BREAKOUT_VALIDITY_HOURS = 24 - - -def is_false_breakout_key_monitor_type(monitor_type: Optional[str]) -> bool: - return (monitor_type or "").strip() == FALSE_BREAKOUT_MONITOR_TYPE - - -def is_limit_key_monitor_type(monitor_type: Optional[str]) -> bool: - from fib_key_monitor_lib import is_fib_key_monitor_type - - return is_fib_key_monitor_type(monitor_type) or is_false_breakout_key_monitor_type(monitor_type) - - -def normalize_false_breakout_symbol(symbol: Optional[str]) -> Optional[str]: - s = (symbol or "").strip().upper() - if not s: - return None - if "/" not in s: - s = f"{s}/USDT" - return s if s in FALSE_BREAKOUT_SYMBOLS else None - - -def storage_bounds_from_key_price(direction: str, key_price: float) -> tuple[float, float]: - k = float(key_price) - if k <= 0: - raise ValueError("关键价位须为正数") - d = (direction or "long").strip().lower() - if d == "short": - return k, k * 0.9999 - if d == "long": - return k * 1.0001, k - raise ValueError("方向须为 long 或 short") - - -def key_price_from_row(direction: str, upper: Any, lower: Any) -> Optional[float]: - d = (direction or "long").strip().lower() - try: - if d == "short": - v = float(upper) - else: - v = float(lower) - except (TypeError, ValueError): - return None - return v if v > 0 else None - - -def calc_false_breakout_plan(direction: str, key_price: float) -> Optional[tuple[float, float, float]]: - try: - k = float(key_price) - except (TypeError, ValueError): - return None - if k <= 0: - return None - d = (direction or "long").strip().lower() - off = FALSE_BREAKOUT_OFFSET_PCT / 100.0 - sl_pct = FALSE_BREAKOUT_SL_PCT / 100.0 - rr = float(FALSE_BREAKOUT_RR) - if d == "short": - entry = k * (1 + off) - sl = entry * (1 + sl_pct) - risk = sl - entry - if risk <= 0: - return None - tp = entry - risk * rr - return entry, sl, tp - if d == "long": - entry = k * (1 - off) - sl = entry * (1 - sl_pct) - risk = entry - sl - if risk <= 0: - return None - tp = entry + risk * rr - return entry, sl, tp - return None - - -def _parse_created_at(raw: Any) -> Optional[datetime]: - s = str(raw or "").strip() - if not s: - return None - for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S"): - try: - return datetime.strptime(s[:26], fmt) - except ValueError: - continue - try: - return datetime.fromisoformat(s.replace("Z", "+00:00")[:32]) - except ValueError: - return None - - -def is_false_breakout_expired( - created_at: Any, - now: datetime, - *, - hours: int = FALSE_BREAKOUT_VALIDITY_HOURS, -) -> bool: - dt = _parse_created_at(created_at) - if dt is None: - return False - return now >= dt + timedelta(hours=hours) - - -def expires_at_text(created_at: Any, *, hours: int = FALSE_BREAKOUT_VALIDITY_HOURS) -> str: - dt = _parse_created_at(created_at) - if dt is None: - return "—" - return (dt + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S") - - -def false_breakout_gate_preview( - *, - entry_display: str, - limit_order_id: Any = None, - created_at: Any = None, - now: Optional[datetime] = None, - hours: int = FALSE_BREAKOUT_VALIDITY_HOURS, -) -> dict[str, Any]: - """假突破门控预览:限价挂单状态,不使用箱体/收敛的量破幅二确门控。""" - now_dt = now or datetime.now() - expired = is_false_breakout_expired(created_at, now_dt, hours=hours) - exp_txt = expires_at_text(created_at, hours=hours) - status = "已过期" if expired else "等待成交" - metrics_parts: list[str] = [] - oid = str(limit_order_id or "").strip() - if oid: - metrics_parts.append(f"限价单:{oid}") - if exp_txt != "—": - metrics_parts.append(f"截至:{exp_txt}") - return { - "summary": f"假突破 挂E={entry_display} {status}", - "metrics": " ".join(metrics_parts), - "gate_ok": not expired, - } +"""假突破关键位监控:BTC/ETH 限价挂单(共享计算与校验)。""" +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Any, Optional + +FALSE_BREAKOUT_MONITOR_TYPE = "假突破" +FALSE_BREAKOUT_SYMBOLS = frozenset({"BTC/USDT", "ETH/USDT"}) +FALSE_BREAKOUT_OFFSET_PCT = 0.1 +FALSE_BREAKOUT_SL_PCT = 0.5 +FALSE_BREAKOUT_RR = 1.5 +FALSE_BREAKOUT_VALIDITY_HOURS = 24 + + +def is_false_breakout_key_monitor_type(monitor_type: Optional[str]) -> bool: + return (monitor_type or "").strip() == FALSE_BREAKOUT_MONITOR_TYPE + + +def is_limit_key_monitor_type(monitor_type: Optional[str]) -> bool: + from lib.key_monitor.fib_key_monitor_lib import is_fib_key_monitor_type + + return is_fib_key_monitor_type(monitor_type) or is_false_breakout_key_monitor_type(monitor_type) + + +def normalize_false_breakout_symbol(symbol: Optional[str]) -> Optional[str]: + s = (symbol or "").strip().upper() + if not s: + return None + if "/" not in s: + s = f"{s}/USDT" + return s if s in FALSE_BREAKOUT_SYMBOLS else None + + +def storage_bounds_from_key_price(direction: str, key_price: float) -> tuple[float, float]: + k = float(key_price) + if k <= 0: + raise ValueError("关键价位须为正数") + d = (direction or "long").strip().lower() + if d == "short": + return k, k * 0.9999 + if d == "long": + return k * 1.0001, k + raise ValueError("方向须为 long 或 short") + + +def key_price_from_row(direction: str, upper: Any, lower: Any) -> Optional[float]: + d = (direction or "long").strip().lower() + try: + if d == "short": + v = float(upper) + else: + v = float(lower) + except (TypeError, ValueError): + return None + return v if v > 0 else None + + +def calc_false_breakout_plan(direction: str, key_price: float) -> Optional[tuple[float, float, float]]: + try: + k = float(key_price) + except (TypeError, ValueError): + return None + if k <= 0: + return None + d = (direction or "long").strip().lower() + off = FALSE_BREAKOUT_OFFSET_PCT / 100.0 + sl_pct = FALSE_BREAKOUT_SL_PCT / 100.0 + rr = float(FALSE_BREAKOUT_RR) + if d == "short": + entry = k * (1 + off) + sl = entry * (1 + sl_pct) + risk = sl - entry + if risk <= 0: + return None + tp = entry - risk * rr + return entry, sl, tp + if d == "long": + entry = k * (1 - off) + sl = entry * (1 - sl_pct) + risk = entry - sl + if risk <= 0: + return None + tp = entry + risk * rr + return entry, sl, tp + return None + + +def _parse_created_at(raw: Any) -> Optional[datetime]: + s = str(raw or "").strip() + if not s: + return None + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S"): + try: + return datetime.strptime(s[:26], fmt) + except ValueError: + continue + try: + return datetime.fromisoformat(s.replace("Z", "+00:00")[:32]) + except ValueError: + return None + + +def is_false_breakout_expired( + created_at: Any, + now: datetime, + *, + hours: int = FALSE_BREAKOUT_VALIDITY_HOURS, +) -> bool: + dt = _parse_created_at(created_at) + if dt is None: + return False + return now >= dt + timedelta(hours=hours) + + +def expires_at_text(created_at: Any, *, hours: int = FALSE_BREAKOUT_VALIDITY_HOURS) -> str: + dt = _parse_created_at(created_at) + if dt is None: + return "—" + return (dt + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S") + + +def false_breakout_gate_preview( + *, + entry_display: str, + limit_order_id: Any = None, + created_at: Any = None, + now: Optional[datetime] = None, + hours: int = FALSE_BREAKOUT_VALIDITY_HOURS, +) -> dict[str, Any]: + """假突破门控预览:限价挂单状态,不使用箱体/收敛的量破幅二确门控。""" + now_dt = now or datetime.now() + expired = is_false_breakout_expired(created_at, now_dt, hours=hours) + exp_txt = expires_at_text(created_at, hours=hours) + status = "已过期" if expired else "等待成交" + metrics_parts: list[str] = [] + oid = str(limit_order_id or "").strip() + if oid: + metrics_parts.append(f"限价单:{oid}") + if exp_txt != "—": + metrics_parts.append(f"截至:{exp_txt}") + return { + "summary": f"假突破 挂E={entry_display} {status}", + "metrics": " ".join(metrics_parts), + "gate_ok": not expired, + } diff --git a/fib_key_monitor_lib.py b/lib/key_monitor/fib_key_monitor_lib.py similarity index 96% rename from fib_key_monitor_lib.py rename to lib/key_monitor/fib_key_monitor_lib.py index ac53577..46df17c 100644 --- a/fib_key_monitor_lib.py +++ b/lib/key_monitor/fib_key_monitor_lib.py @@ -1,140 +1,140 @@ -"""斐波关键位监控:纯计算与类型判断(Gate / Binance 主站共用)。""" - -from key_monitor_lib import KEY_MONITOR_AUTO_TYPES - -FIB_KEY_MONITOR_TYPES = frozenset({"斐波回调0.618", "斐波回调0.786"}) -KEY_MONITOR_TRADE_TYPE = "关键位监控" - -FIB_RATIO_BY_TYPE = { - "斐波回调0.618": 0.618, - "斐波回调0.786": 0.786, -} - - -def is_fib_key_monitor_type(monitor_type): - return (monitor_type or "").strip() in FIB_KEY_MONITOR_TYPES - - -def fib_ratio_from_type(monitor_type): - return FIB_RATIO_BY_TYPE.get((monitor_type or "").strip()) - - -def calc_fib_plan(direction, upper, lower, ratio): - """ - 上沿 H、下沿 L(H > L)。 - 做多:自 H 向下回撤 ratio,E = H - ratio*(H-L);SL=L,TP=H。 - 做空:自 L 向上反弹 ratio,E = L + ratio*(H-L);SL=H,TP=L。 - 返回 (entry, stop_loss, take_profit) 或 None。 - """ - try: - h = float(upper) - l = float(lower) - r = float(ratio) - except (TypeError, ValueError): - return None - if h <= l or r <= 0 or r >= 1: - return None - span = h - l - direction = (direction or "long").strip().lower() - if direction == "short": - entry = l + r * span - return entry, h, l - entry = h - r * span - return entry, l, h - - -def stored_key_signal_type(monitor_type): - """写入 order_monitors / trade_records 的 key_signal_type(箱体/收敛/斐波/假突破/触价开仓)。""" - mt = (monitor_type or "").strip() - if mt in FIB_KEY_MONITOR_TYPES: - return mt - if mt in ("假突破", "回调触价开仓", "突破触价开仓", "触价开仓"): - return mt if mt != "触价开仓" else "回调触价开仓" - if mt in KEY_MONITOR_AUTO_TYPES: - return mt - return None - - -KEY_ENTRY_REASON_BY_SIGNAL = { - "箱体突破": "关键位箱体突破", - "收敛突破": "关键位收敛突破", - "斐波回调0.618": "关键位斐波0.618", - "斐波回调0.786": "关键位斐波0.786", - "假突破": "关键位假突破", - "回调触价开仓": "关键位回调触价开仓", - "突破触价开仓": "关键位突破触价开仓", - "触价开仓": "关键位触价开仓", - "趋势回调": "趋势回调", -} - - -def entry_reason_from_key_signal(key_signal_type): - return KEY_ENTRY_REASON_BY_SIGNAL.get((key_signal_type or "").strip()) - - -def key_signal_type_for_trade_record(key_signal_type, box_auto_types): - """平仓写入 trade_records 时保留箱体/收敛/斐波/假突破来源。""" - kst = (key_signal_type or "").strip() - if kst in FIB_KEY_MONITOR_TYPES: - return kst - if kst in ("假突破", "回调触价开仓", "突破触价开仓", "触价开仓"): - return kst if kst != "触价开仓" else "回调触价开仓" - if box_auto_types and kst in box_auto_types: - return kst - return None - - -def backfill_missing_key_signal_types(conn, *, monitor_type: str = KEY_MONITOR_TRADE_TYPE) -> int: - """补全历史 trade_records / order_monitors 中缺失的箱体/收敛 key_signal_type。""" - mt = (monitor_type or KEY_MONITOR_TRADE_TYPE).strip() - updated = 0 - for signal in KEY_MONITOR_AUTO_TYPES: - entry_reason = KEY_ENTRY_REASON_BY_SIGNAL.get(signal) - if entry_reason: - cur = conn.execute( - """UPDATE trade_records SET key_signal_type=? - WHERE monitor_type=? AND (key_signal_type IS NULL OR TRIM(key_signal_type)='') - AND TRIM(COALESCE(entry_reason, ''))=?""", - (signal, mt, entry_reason), - ) - updated += int(cur.rowcount or 0) - rows = conn.execute( - """SELECT id, symbol, opened_at FROM trade_records - WHERE monitor_type=? AND (key_signal_type IS NULL OR TRIM(key_signal_type)='')""", - (mt,), - ).fetchall() - for row in rows: - # init_db 连接未设 row_factory,结果为 tuple - rid, sym, opened_at = row[0], row[1], row[2] - opened = (opened_at or "").strip() - for signal in KEY_MONITOR_AUTO_TYPES: - hist = conn.execute( - """SELECT monitor_type FROM key_monitor_history - WHERE symbol=? AND monitor_type=? AND close_reason='auto_opened' - AND (?='' OR closed_at <= ?) - ORDER BY closed_at DESC LIMIT 1""", - (sym, signal, opened, opened), - ).fetchone() - if not hist: - continue - conn.execute( - "UPDATE trade_records SET key_signal_type=? WHERE id=?", - (signal, rid), - ) - updated += 1 - break - return updated - - -def fib_invalidate_by_mark(direction, mark_price, upper, lower): - """先触达止盈侧(标记价)则失效。多:mark>=H;空:mark<=L。""" - try: - m = float(mark_price) - h = float(upper) - l = float(lower) - except (TypeError, ValueError): - return False - direction = (direction or "long").strip().lower() - if direction == "short": - return m <= l - return m >= h +"""斐波关键位监控:纯计算与类型判断(Gate / Binance 主站共用)。""" + +from lib.key_monitor.key_monitor_lib import KEY_MONITOR_AUTO_TYPES + +FIB_KEY_MONITOR_TYPES = frozenset({"斐波回调0.618", "斐波回调0.786"}) +KEY_MONITOR_TRADE_TYPE = "关键位监控" + +FIB_RATIO_BY_TYPE = { + "斐波回调0.618": 0.618, + "斐波回调0.786": 0.786, +} + + +def is_fib_key_monitor_type(monitor_type): + return (monitor_type or "").strip() in FIB_KEY_MONITOR_TYPES + + +def fib_ratio_from_type(monitor_type): + return FIB_RATIO_BY_TYPE.get((monitor_type or "").strip()) + + +def calc_fib_plan(direction, upper, lower, ratio): + """ + 上沿 H、下沿 L(H > L)。 + 做多:自 H 向下回撤 ratio,E = H - ratio*(H-L);SL=L,TP=H。 + 做空:自 L 向上反弹 ratio,E = L + ratio*(H-L);SL=H,TP=L。 + 返回 (entry, stop_loss, take_profit) 或 None。 + """ + try: + h = float(upper) + l = float(lower) + r = float(ratio) + except (TypeError, ValueError): + return None + if h <= l or r <= 0 or r >= 1: + return None + span = h - l + direction = (direction or "long").strip().lower() + if direction == "short": + entry = l + r * span + return entry, h, l + entry = h - r * span + return entry, l, h + + +def stored_key_signal_type(monitor_type): + """写入 order_monitors / trade_records 的 key_signal_type(箱体/收敛/斐波/假突破/触价开仓)。""" + mt = (monitor_type or "").strip() + if mt in FIB_KEY_MONITOR_TYPES: + return mt + if mt in ("假突破", "回调触价开仓", "突破触价开仓", "触价开仓"): + return mt if mt != "触价开仓" else "回调触价开仓" + if mt in KEY_MONITOR_AUTO_TYPES: + return mt + return None + + +KEY_ENTRY_REASON_BY_SIGNAL = { + "箱体突破": "关键位箱体突破", + "收敛突破": "关键位收敛突破", + "斐波回调0.618": "关键位斐波0.618", + "斐波回调0.786": "关键位斐波0.786", + "假突破": "关键位假突破", + "回调触价开仓": "关键位回调触价开仓", + "突破触价开仓": "关键位突破触价开仓", + "触价开仓": "关键位触价开仓", + "趋势回调": "趋势回调", +} + + +def entry_reason_from_key_signal(key_signal_type): + return KEY_ENTRY_REASON_BY_SIGNAL.get((key_signal_type or "").strip()) + + +def key_signal_type_for_trade_record(key_signal_type, box_auto_types): + """平仓写入 trade_records 时保留箱体/收敛/斐波/假突破来源。""" + kst = (key_signal_type or "").strip() + if kst in FIB_KEY_MONITOR_TYPES: + return kst + if kst in ("假突破", "回调触价开仓", "突破触价开仓", "触价开仓"): + return kst if kst != "触价开仓" else "回调触价开仓" + if box_auto_types and kst in box_auto_types: + return kst + return None + + +def backfill_missing_key_signal_types(conn, *, monitor_type: str = KEY_MONITOR_TRADE_TYPE) -> int: + """补全历史 trade_records / order_monitors 中缺失的箱体/收敛 key_signal_type。""" + mt = (monitor_type or KEY_MONITOR_TRADE_TYPE).strip() + updated = 0 + for signal in KEY_MONITOR_AUTO_TYPES: + entry_reason = KEY_ENTRY_REASON_BY_SIGNAL.get(signal) + if entry_reason: + cur = conn.execute( + """UPDATE trade_records SET key_signal_type=? + WHERE monitor_type=? AND (key_signal_type IS NULL OR TRIM(key_signal_type)='') + AND TRIM(COALESCE(entry_reason, ''))=?""", + (signal, mt, entry_reason), + ) + updated += int(cur.rowcount or 0) + rows = conn.execute( + """SELECT id, symbol, opened_at FROM trade_records + WHERE monitor_type=? AND (key_signal_type IS NULL OR TRIM(key_signal_type)='')""", + (mt,), + ).fetchall() + for row in rows: + # init_db 连接未设 row_factory,结果为 tuple + rid, sym, opened_at = row[0], row[1], row[2] + opened = (opened_at or "").strip() + for signal in KEY_MONITOR_AUTO_TYPES: + hist = conn.execute( + """SELECT monitor_type FROM key_monitor_history + WHERE symbol=? AND monitor_type=? AND close_reason='auto_opened' + AND (?='' OR closed_at <= ?) + ORDER BY closed_at DESC LIMIT 1""", + (sym, signal, opened, opened), + ).fetchone() + if not hist: + continue + conn.execute( + "UPDATE trade_records SET key_signal_type=? WHERE id=?", + (signal, rid), + ) + updated += 1 + break + return updated + + +def fib_invalidate_by_mark(direction, mark_price, upper, lower): + """先触达止盈侧(标记价)则失效。多:mark>=H;空:mark<=L。""" + try: + m = float(mark_price) + h = float(upper) + l = float(lower) + except (TypeError, ValueError): + return False + direction = (direction or "long").strip().lower() + if direction == "short": + return m <= l + return m >= h diff --git a/key_monitor_full_margin_lib.py b/lib/key_monitor/key_monitor_full_margin_lib.py similarity index 83% rename from key_monitor_full_margin_lib.py rename to lib/key_monitor/key_monitor_full_margin_lib.py index 7b785e9..00a2494 100644 --- a/key_monitor_full_margin_lib.py +++ b/lib/key_monitor/key_monitor_full_margin_lib.py @@ -1,61 +1,61 @@ -""" -全仓杠杆模式下:撤销已添加的箱体/收敛/斐波关键位监控并微信说明。 -""" -from __future__ import annotations - -from typing import Any, Callable, Iterable, Optional - -from fib_key_monitor_lib import FIB_KEY_MONITOR_TYPES, is_fib_key_monitor_type -from false_breakout_key_monitor_lib import is_false_breakout_key_monitor_type -from key_monitor_lib import KEY_MONITOR_AUTO_TYPES -from position_sizing_lib import is_full_margin_mode, mode_label_zh - - -def monitor_type_disallowed_in_full_margin(monitor_type: str) -> bool: - mt = (monitor_type or "").strip() - if mt in KEY_MONITOR_AUTO_TYPES: - return True - if is_fib_key_monitor_type(mt): - return True - return is_false_breakout_key_monitor_type(mt) - - -def purge_disallowed_key_monitors( - conn: Any, - *, - sizing_mode: str, - select_rows: Callable[[Any], Iterable[Any]], - cancel_fib_limit: Callable[[Any], None], - delete_monitor: Callable[[Any, int], None], - send_wechat: Callable[[str], None], - row_symbol: Callable[[Any], str] = lambda r: str(r["symbol"] or ""), - row_monitor_type: Callable[[Any], str] = lambda r: str(r["monitor_type"] or ""), - row_id: Callable[[Any], int] = lambda r: int(r["id"]), -) -> int: - if not is_full_margin_mode(sizing_mode): - return 0 - removed = [] - for row in select_rows(conn): - mt = row_monitor_type(row) - if not monitor_type_disallowed_in_full_margin(mt): - continue - sym = row_symbol(row) - kid = row_id(row) - if is_fib_key_monitor_type(mt) or is_false_breakout_key_monitor_type(mt): - try: - cancel_fib_limit(row) - except Exception: - pass - delete_monitor(conn, kid) - removed.append((sym, mt, kid)) - if removed: - lines = [f"· {s} {t} (#{i})" for s, t, i in removed[:12]] - if len(removed) > 12: - lines.append(f"… 共 {len(removed)} 条") - send_wechat( - "# ⚠️ 全仓杠杆模式:已自动撤销关键位监控\n" - f"计仓模式:{mode_label_zh(sizing_mode)}(仅 env 可切换,须无仓)\n" - "已撤销:箱体突破 / 收敛突破 / 斐波回调 / 假突破监控(不可与全仓杠杆并存)\n" - + "\n".join(lines) - ) - return len(removed) +""" +全仓杠杆模式下:撤销已添加的箱体/收敛/斐波关键位监控并微信说明。 +""" +from __future__ import annotations + +from typing import Any, Callable, Iterable, Optional + +from lib.key_monitor.fib_key_monitor_lib import FIB_KEY_MONITOR_TYPES, is_fib_key_monitor_type +from lib.key_monitor.false_breakout_key_monitor_lib import is_false_breakout_key_monitor_type +from lib.key_monitor.key_monitor_lib import KEY_MONITOR_AUTO_TYPES +from lib.trade.position_sizing_lib import is_full_margin_mode, mode_label_zh + + +def monitor_type_disallowed_in_full_margin(monitor_type: str) -> bool: + mt = (monitor_type or "").strip() + if mt in KEY_MONITOR_AUTO_TYPES: + return True + if is_fib_key_monitor_type(mt): + return True + return is_false_breakout_key_monitor_type(mt) + + +def purge_disallowed_key_monitors( + conn: Any, + *, + sizing_mode: str, + select_rows: Callable[[Any], Iterable[Any]], + cancel_fib_limit: Callable[[Any], None], + delete_monitor: Callable[[Any, int], None], + send_wechat: Callable[[str], None], + row_symbol: Callable[[Any], str] = lambda r: str(r["symbol"] or ""), + row_monitor_type: Callable[[Any], str] = lambda r: str(r["monitor_type"] or ""), + row_id: Callable[[Any], int] = lambda r: int(r["id"]), +) -> int: + if not is_full_margin_mode(sizing_mode): + return 0 + removed = [] + for row in select_rows(conn): + mt = row_monitor_type(row) + if not monitor_type_disallowed_in_full_margin(mt): + continue + sym = row_symbol(row) + kid = row_id(row) + if is_fib_key_monitor_type(mt) or is_false_breakout_key_monitor_type(mt): + try: + cancel_fib_limit(row) + except Exception: + pass + delete_monitor(conn, kid) + removed.append((sym, mt, kid)) + if removed: + lines = [f"· {s} {t} (#{i})" for s, t, i in removed[:12]] + if len(removed) > 12: + lines.append(f"… 共 {len(removed)} 条") + send_wechat( + "# ⚠️ 全仓杠杆模式:已自动撤销关键位监控\n" + f"计仓模式:{mode_label_zh(sizing_mode)}(仅 env 可切换,须无仓)\n" + "已撤销:箱体突破 / 收敛突破 / 斐波回调 / 假突破监控(不可与全仓杠杆并存)\n" + + "\n".join(lines) + ) + return len(removed) diff --git a/key_monitor_lib.py b/lib/key_monitor/key_monitor_lib.py similarity index 95% rename from key_monitor_lib.py rename to lib/key_monitor/key_monitor_lib.py index 37324e5..e16e342 100644 --- a/key_monitor_lib.py +++ b/lib/key_monitor/key_monitor_lib.py @@ -1,390 +1,390 @@ -""" -关键位监控:阻力/支撑双向提醒与箱体/收敛自动门控的共享逻辑。 -""" -from __future__ import annotations - -from datetime import datetime -from typing import Any, Optional - -KEY_MONITOR_AUTO_TYPES = frozenset({"箱体突破", "收敛突破"}) -KEY_MONITOR_RS_TYPE = "关键支撑阻力" -KEY_MONITOR_RS_LEGACY_TYPES = frozenset({"关键阻力位", "关键支撑位"}) -KEY_MONITOR_RS_TYPES = frozenset({KEY_MONITOR_RS_TYPE}) | KEY_MONITOR_RS_LEGACY_TYPES -KEY_MONITOR_ALERT_ONLY_TYPES = frozenset({KEY_MONITOR_RS_TYPE}) | KEY_MONITOR_RS_LEGACY_TYPES -KEY_DIRECTION_WATCH = "watch" - - -def is_rs_key_monitor_type(monitor_type: str) -> bool: - return (monitor_type or "").strip() in KEY_MONITOR_RS_TYPES - - -def rs_monitor_type_label(monitor_type: str) -> str: - """展示用:旧库里的阻力/支撑合并为「关键支撑阻力」。""" - if is_rs_key_monitor_type(monitor_type): - return KEY_MONITOR_RS_TYPE - return (monitor_type or "").strip() - - -def rs_monitor_type_for_storage(monitor_type: str) -> str: - if is_rs_key_monitor_type(monitor_type): - return KEY_MONITOR_RS_TYPE - return (monitor_type or "").strip() - - -def calc_breakout_breach_pct(direction: str, close: float, upper: float, lower: float) -> float: - """突破 K 收盘相对关键位的越过幅度(%)。未越过对应边界时返回 0。""" - direction = (direction or "long").strip().lower() - c = float(close) - if direction == "long": - u = float(upper) - if u <= 0 or c <= u: - return 0.0 - return (c - u) / u * 100.0 - lo = float(lower) - if lo <= 0 or c >= lo: - return 0.0 - return (lo - c) / lo * 100.0 - - -def auto_amp_ok( - direction: str, - close_b: float, - upper: float, - lower: float, - min_pct: float, -) -> tuple[bool, float]: - breach = calc_breakout_breach_pct(direction, close_b, upper, lower) - return breach > float(min_pct), breach - - -def auto_confirm_ok(direction: str, cfm_close: float, upper: float, lower: float) -> bool: - """确认 K 收盘须在箱体外(不得回到 [lower, upper] 内)。""" - direction = (direction or "long").strip().lower() - c = float(cfm_close) - if direction == "long": - return c > float(upper) - return c < float(lower) - - -BOX_BREAKOUT_CLOSE_OPPOSITE = "box_opposite_break" - - -def box_breakout_invalidate_by_mark( - direction: str, mark_price: float, upper: float, lower: float -) -> bool: - """箱体/收敛:标记价先突破反向边界则失效。多:mark<=L;空:mark>=H。""" - try: - m = float(mark_price) - h = float(upper) - lo = float(lower) - except (TypeError, ValueError): - return False - direction = (direction or "long").strip().lower() - if direction == "short": - return m >= h - return m <= lo - - -def box_breakout_invalidate_edge_label(direction: str) -> str: - direction = (direction or "long").strip().lower() - return "下沿" if direction == "long" else "上沿" - - -def detect_rs_box_break(close: float, upper: float, lower: float) -> Optional[dict[str, Any]]: - """ - 阻力/支撑人工盯盘:最近 5m 收盘突破上沿或下沿(严格 > / <)。 - 上沿优先:同一根 K 不可能同时满足两者。 - """ - u, lo, c = float(upper), float(lower), float(close) - if c > u: - return { - "break_side": "upper", - "direction": "long", - "edge_price": u, - "key_price": u, - "break_label": "向上突破上沿", - } - if c < lo: - return { - "break_side": "lower", - "direction": "short", - "edge_price": lo, - "key_price": lo, - "break_label": "向下突破下沿", - } - return None - - -def rs_break_from_direction(direction: str, upper: float, lower: float) -> Optional[dict[str, Any]]: - """已触发后根据入库方向还原突破边(long=上沿,short=下沿)。""" - d = (direction or "").strip().lower() - if d == "long": - return { - "break_side": "upper", - "direction": "long", - "edge_price": float(upper), - "key_price": float(upper), - "break_label": "向上突破上沿", - } - if d == "short": - return { - "break_side": "lower", - "direction": "short", - "edge_price": float(lower), - "key_price": float(lower), - "break_label": "向下突破下沿", - } - return None - - -def rs_break_infer_from_close(close: float, upper: float, lower: float) -> dict[str, Any]: - """ - 续发提醒时价格已回到箱体内:按收盘价相对箱体中线推断首次突破边, - 保证第 2/3 次企业微信提醒仍能发出。 - """ - mid = (float(upper) + float(lower)) / 2.0 - if float(close) >= mid: - br = rs_break_from_direction("long", upper, lower) - else: - br = rs_break_from_direction("short", upper, lower) - if br: - return br - return { - "break_side": "upper", - "direction": "long", - "edge_price": float(upper), - "key_price": float(upper), - "break_label": "向上突破上沿", - } - - -def _parse_notify_datetime(raw: Optional[str]) -> Optional[datetime]: - s = str(raw or "").strip() - if not s: - return None - try: - dt = datetime.fromisoformat(s.replace("Z", "+00:00")) - if dt.tzinfo is not None: - dt = dt.replace(tzinfo=None) - return dt - except Exception: - pass - for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"): - try: - return datetime.strptime(s[:19], fmt) - except Exception: - continue - return None - - -def claim_rs_level_notify( - conn: Any, - monitor_id: int, - notify_index: int, - direction: str, - notified_at: str, - bar_ts: Optional[int], - *, - prior_count: Optional[int] = None, -) -> bool: - """ - 原子占位:仅在 notification_count 仍为 prior_count 时推进到 notify_index。 - 须在发送企业微信之前调用并 commit,避免 (2/3) 重复刷屏。 - """ - prior = int(prior_count if prior_count is not None else notify_index - 1) - if prior < 0 or notify_index != prior + 1: - return False - bar_val: Optional[int] = None - if bar_ts is not None: - try: - bar_val = int(bar_ts) - except (TypeError, ValueError): - bar_val = None - cur = conn.execute( - "UPDATE key_monitors SET notification_count=?, direction=?, last_notified_at=?, last_rs_bar_ts=? " - "WHERE id=? AND COALESCE(notification_count,0)=?", - (notify_index, direction, notified_at, bar_val, int(monitor_id), prior), - ) - return int(cur.rowcount or 0) > 0 - - -def parse_last_rs_bar_ts(row: Any) -> Optional[int]: - if row is None: - return None - try: - keys = row.keys() if hasattr(row, "keys") else [] - except Exception: - keys = [] - raw = row["last_rs_bar_ts"] if "last_rs_bar_ts" in keys else None - if raw is None: - return None - try: - return int(raw) - except (TypeError, ValueError): - return None - - -def run_rs_level_alert_tick( - row: Any, - close: float, - bar_ts: Optional[int], - now_dt: datetime, - *, - default_max_notify: int, - default_interval_min: int, -) -> Optional[dict[str, Any]]: - """ - 判定本轮回合是否应推送阻力/支撑提醒。 - 首条:仅在新闭合 K 越线时触发;发送前须 claim_rs_level_notify 占位防轮询/多进程重复。 - """ - up, lo = float(row["upper"]), float(row["lower"]) - if up <= lo: - return None - count = int(row["notification_count"] or 0) - max_n = max(1, int(row["max_notify"] or default_max_notify)) - interval = max(1, int(row["notify_interval_min"] or default_interval_min)) - if count >= max_n: - return None - - bar_ts_i: Optional[int] = None - if bar_ts is not None: - try: - bar_ts_i = int(bar_ts) - except (TypeError, ValueError): - bar_ts_i = None - last_bar_i = parse_last_rs_bar_ts(row) - - if count == 0: - br = detect_rs_box_break(close, up, lo) - if not br: - return None - if bar_ts_i is not None and last_bar_i is not None and bar_ts_i == last_bar_i: - return None - return { - "break_info": br, - "notify_index": 1, - "prior_count": 0, - "notify_max": max_n, - "interval_min": interval, - "bar_ts": bar_ts_i, - } - - if not notify_interval_elapsed(row["last_notified_at"], interval, now_dt): - return None - br = resolve_rs_break_for_alert(count, row["direction"], close, up, lo) - if not br: - return None - return { - "break_info": br, - "notify_index": count + 1, - "prior_count": count, - "notify_max": max_n, - "interval_min": interval, - "bar_ts": bar_ts_i, - } - - -def resolve_rs_break_for_alert( - notification_count: int, - direction: Optional[str], - close: float, - upper: float, - lower: float, -) -> Optional[dict[str, Any]]: - """ - 阻力/支撑提醒:首次用 5m 收盘越线判定;后续用已存方向,兼容 direction=watch。 - """ - count = int(notification_count or 0) - up, lo, c = float(upper), float(lower), float(close) - if count <= 0: - return detect_rs_box_break(c, up, lo) - br = rs_break_from_direction(direction, up, lo) - if br: - return br - d = (direction or "").strip().lower() - if d not in ("", KEY_DIRECTION_WATCH): - return None - br = detect_rs_box_break(c, up, lo) - if br: - return br - return rs_break_infer_from_close(c, up, lo) - - -def notify_interval_elapsed( - last_notified_at: Optional[str], - interval_min: int, - now_dt: datetime, -) -> bool: - if not last_notified_at: - return False - last_dt = _parse_notify_datetime(last_notified_at) - if last_dt is None: - return False - return (now_dt - last_dt).total_seconds() >= max(1, int(interval_min)) * 60 - - -def format_auto_amp_line(amp_ok: bool, amp_pct: float, min_pct: float) -> str: - return ( - f"突破越过幅度:{'通过' if amp_ok else '不通过'}" - f"({round(float(amp_pct), 4)}%,要求 > {min_pct}%)" - ) - - -def format_auto_confirm_line(confirm_ok: bool, cfm_close, edge_price, direction: str) -> str: - side = "箱外上方" if (direction or "").lower() == "long" else "箱外下方" - return ( - f"第二根确认:{'通过' if confirm_ok else '不通过'}" - f"(确认收盘 {cfm_close},须收于{side},关键位 {edge_price})" - ) - - -def key_monitor_rule_template_context( - *, - kline_timeframe: str, - key_breakout_amp_min_pct: float, - key_volume_ma_bars: int, - key_volume_ratio_min: float, - key_auto_min_planned_rr: float, - key_daily_volume_rank_max: int, - key_confirm_breakout_bar: int, - key_confirm_bar: int, - key_alert_max_times: int, - key_alert_interval_minutes: int, - key_stop_outside_breakout_pct: float, - key_trend_stop_outside_pct: float, - false_breakout_validity_hours: int, - trigger_entry_validity_hours: int | None = None, -) -> dict[str, Any]: - """关键位监控页规则说明表格(Jinja key_rule_ctx)。""" - from false_breakout_key_monitor_lib import ( - FALSE_BREAKOUT_OFFSET_PCT, - FALSE_BREAKOUT_RR, - FALSE_BREAKOUT_SL_PCT, - ) - from trigger_entry_key_monitor_lib import TRIGGER_ENTRY_VALIDITY_HOURS - - te_hours = ( - int(trigger_entry_validity_hours) - if trigger_entry_validity_hours is not None - else TRIGGER_ENTRY_VALIDITY_HOURS - ) - - return { - "tf": (kline_timeframe or "5m").strip(), - "amp_min_pct": key_breakout_amp_min_pct, - "vol_ma_bars": key_volume_ma_bars, - "vol_ratio_min": key_volume_ratio_min, - "min_rr": key_auto_min_planned_rr, - "vol_rank_max": key_daily_volume_rank_max, - "breakout_bar": key_confirm_breakout_bar, - "confirm_bar": key_confirm_bar, - "alert_max": key_alert_max_times, - "alert_interval_min": key_alert_interval_minutes, - "stop_outside_pct": key_stop_outside_breakout_pct, - "trend_stop_outside_pct": key_trend_stop_outside_pct, - "fb_offset_pct": FALSE_BREAKOUT_OFFSET_PCT, - "fb_sl_pct": FALSE_BREAKOUT_SL_PCT, - "fb_rr": FALSE_BREAKOUT_RR, - "fb_valid_hours": false_breakout_validity_hours, - "trigger_entry_validity_hours": te_hours, - } +""" +关键位监控:阻力/支撑双向提醒与箱体/收敛自动门控的共享逻辑。 +""" +from __future__ import annotations + +from datetime import datetime +from typing import Any, Optional + +KEY_MONITOR_AUTO_TYPES = frozenset({"箱体突破", "收敛突破"}) +KEY_MONITOR_RS_TYPE = "关键支撑阻力" +KEY_MONITOR_RS_LEGACY_TYPES = frozenset({"关键阻力位", "关键支撑位"}) +KEY_MONITOR_RS_TYPES = frozenset({KEY_MONITOR_RS_TYPE}) | KEY_MONITOR_RS_LEGACY_TYPES +KEY_MONITOR_ALERT_ONLY_TYPES = frozenset({KEY_MONITOR_RS_TYPE}) | KEY_MONITOR_RS_LEGACY_TYPES +KEY_DIRECTION_WATCH = "watch" + + +def is_rs_key_monitor_type(monitor_type: str) -> bool: + return (monitor_type or "").strip() in KEY_MONITOR_RS_TYPES + + +def rs_monitor_type_label(monitor_type: str) -> str: + """展示用:旧库里的阻力/支撑合并为「关键支撑阻力」。""" + if is_rs_key_monitor_type(monitor_type): + return KEY_MONITOR_RS_TYPE + return (monitor_type or "").strip() + + +def rs_monitor_type_for_storage(monitor_type: str) -> str: + if is_rs_key_monitor_type(monitor_type): + return KEY_MONITOR_RS_TYPE + return (monitor_type or "").strip() + + +def calc_breakout_breach_pct(direction: str, close: float, upper: float, lower: float) -> float: + """突破 K 收盘相对关键位的越过幅度(%)。未越过对应边界时返回 0。""" + direction = (direction or "long").strip().lower() + c = float(close) + if direction == "long": + u = float(upper) + if u <= 0 or c <= u: + return 0.0 + return (c - u) / u * 100.0 + lo = float(lower) + if lo <= 0 or c >= lo: + return 0.0 + return (lo - c) / lo * 100.0 + + +def auto_amp_ok( + direction: str, + close_b: float, + upper: float, + lower: float, + min_pct: float, +) -> tuple[bool, float]: + breach = calc_breakout_breach_pct(direction, close_b, upper, lower) + return breach > float(min_pct), breach + + +def auto_confirm_ok(direction: str, cfm_close: float, upper: float, lower: float) -> bool: + """确认 K 收盘须在箱体外(不得回到 [lower, upper] 内)。""" + direction = (direction or "long").strip().lower() + c = float(cfm_close) + if direction == "long": + return c > float(upper) + return c < float(lower) + + +BOX_BREAKOUT_CLOSE_OPPOSITE = "box_opposite_break" + + +def box_breakout_invalidate_by_mark( + direction: str, mark_price: float, upper: float, lower: float +) -> bool: + """箱体/收敛:标记价先突破反向边界则失效。多:mark<=L;空:mark>=H。""" + try: + m = float(mark_price) + h = float(upper) + lo = float(lower) + except (TypeError, ValueError): + return False + direction = (direction or "long").strip().lower() + if direction == "short": + return m >= h + return m <= lo + + +def box_breakout_invalidate_edge_label(direction: str) -> str: + direction = (direction or "long").strip().lower() + return "下沿" if direction == "long" else "上沿" + + +def detect_rs_box_break(close: float, upper: float, lower: float) -> Optional[dict[str, Any]]: + """ + 阻力/支撑人工盯盘:最近 5m 收盘突破上沿或下沿(严格 > / <)。 + 上沿优先:同一根 K 不可能同时满足两者。 + """ + u, lo, c = float(upper), float(lower), float(close) + if c > u: + return { + "break_side": "upper", + "direction": "long", + "edge_price": u, + "key_price": u, + "break_label": "向上突破上沿", + } + if c < lo: + return { + "break_side": "lower", + "direction": "short", + "edge_price": lo, + "key_price": lo, + "break_label": "向下突破下沿", + } + return None + + +def rs_break_from_direction(direction: str, upper: float, lower: float) -> Optional[dict[str, Any]]: + """已触发后根据入库方向还原突破边(long=上沿,short=下沿)。""" + d = (direction or "").strip().lower() + if d == "long": + return { + "break_side": "upper", + "direction": "long", + "edge_price": float(upper), + "key_price": float(upper), + "break_label": "向上突破上沿", + } + if d == "short": + return { + "break_side": "lower", + "direction": "short", + "edge_price": float(lower), + "key_price": float(lower), + "break_label": "向下突破下沿", + } + return None + + +def rs_break_infer_from_close(close: float, upper: float, lower: float) -> dict[str, Any]: + """ + 续发提醒时价格已回到箱体内:按收盘价相对箱体中线推断首次突破边, + 保证第 2/3 次企业微信提醒仍能发出。 + """ + mid = (float(upper) + float(lower)) / 2.0 + if float(close) >= mid: + br = rs_break_from_direction("long", upper, lower) + else: + br = rs_break_from_direction("short", upper, lower) + if br: + return br + return { + "break_side": "upper", + "direction": "long", + "edge_price": float(upper), + "key_price": float(upper), + "break_label": "向上突破上沿", + } + + +def _parse_notify_datetime(raw: Optional[str]) -> Optional[datetime]: + s = str(raw or "").strip() + if not s: + return None + try: + dt = datetime.fromisoformat(s.replace("Z", "+00:00")) + if dt.tzinfo is not None: + dt = dt.replace(tzinfo=None) + return dt + except Exception: + pass + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"): + try: + return datetime.strptime(s[:19], fmt) + except Exception: + continue + return None + + +def claim_rs_level_notify( + conn: Any, + monitor_id: int, + notify_index: int, + direction: str, + notified_at: str, + bar_ts: Optional[int], + *, + prior_count: Optional[int] = None, +) -> bool: + """ + 原子占位:仅在 notification_count 仍为 prior_count 时推进到 notify_index。 + 须在发送企业微信之前调用并 commit,避免 (2/3) 重复刷屏。 + """ + prior = int(prior_count if prior_count is not None else notify_index - 1) + if prior < 0 or notify_index != prior + 1: + return False + bar_val: Optional[int] = None + if bar_ts is not None: + try: + bar_val = int(bar_ts) + except (TypeError, ValueError): + bar_val = None + cur = conn.execute( + "UPDATE key_monitors SET notification_count=?, direction=?, last_notified_at=?, last_rs_bar_ts=? " + "WHERE id=? AND COALESCE(notification_count,0)=?", + (notify_index, direction, notified_at, bar_val, int(monitor_id), prior), + ) + return int(cur.rowcount or 0) > 0 + + +def parse_last_rs_bar_ts(row: Any) -> Optional[int]: + if row is None: + return None + try: + keys = row.keys() if hasattr(row, "keys") else [] + except Exception: + keys = [] + raw = row["last_rs_bar_ts"] if "last_rs_bar_ts" in keys else None + if raw is None: + return None + try: + return int(raw) + except (TypeError, ValueError): + return None + + +def run_rs_level_alert_tick( + row: Any, + close: float, + bar_ts: Optional[int], + now_dt: datetime, + *, + default_max_notify: int, + default_interval_min: int, +) -> Optional[dict[str, Any]]: + """ + 判定本轮回合是否应推送阻力/支撑提醒。 + 首条:仅在新闭合 K 越线时触发;发送前须 claim_rs_level_notify 占位防轮询/多进程重复。 + """ + up, lo = float(row["upper"]), float(row["lower"]) + if up <= lo: + return None + count = int(row["notification_count"] or 0) + max_n = max(1, int(row["max_notify"] or default_max_notify)) + interval = max(1, int(row["notify_interval_min"] or default_interval_min)) + if count >= max_n: + return None + + bar_ts_i: Optional[int] = None + if bar_ts is not None: + try: + bar_ts_i = int(bar_ts) + except (TypeError, ValueError): + bar_ts_i = None + last_bar_i = parse_last_rs_bar_ts(row) + + if count == 0: + br = detect_rs_box_break(close, up, lo) + if not br: + return None + if bar_ts_i is not None and last_bar_i is not None and bar_ts_i == last_bar_i: + return None + return { + "break_info": br, + "notify_index": 1, + "prior_count": 0, + "notify_max": max_n, + "interval_min": interval, + "bar_ts": bar_ts_i, + } + + if not notify_interval_elapsed(row["last_notified_at"], interval, now_dt): + return None + br = resolve_rs_break_for_alert(count, row["direction"], close, up, lo) + if not br: + return None + return { + "break_info": br, + "notify_index": count + 1, + "prior_count": count, + "notify_max": max_n, + "interval_min": interval, + "bar_ts": bar_ts_i, + } + + +def resolve_rs_break_for_alert( + notification_count: int, + direction: Optional[str], + close: float, + upper: float, + lower: float, +) -> Optional[dict[str, Any]]: + """ + 阻力/支撑提醒:首次用 5m 收盘越线判定;后续用已存方向,兼容 direction=watch。 + """ + count = int(notification_count or 0) + up, lo, c = float(upper), float(lower), float(close) + if count <= 0: + return detect_rs_box_break(c, up, lo) + br = rs_break_from_direction(direction, up, lo) + if br: + return br + d = (direction or "").strip().lower() + if d not in ("", KEY_DIRECTION_WATCH): + return None + br = detect_rs_box_break(c, up, lo) + if br: + return br + return rs_break_infer_from_close(c, up, lo) + + +def notify_interval_elapsed( + last_notified_at: Optional[str], + interval_min: int, + now_dt: datetime, +) -> bool: + if not last_notified_at: + return False + last_dt = _parse_notify_datetime(last_notified_at) + if last_dt is None: + return False + return (now_dt - last_dt).total_seconds() >= max(1, int(interval_min)) * 60 + + +def format_auto_amp_line(amp_ok: bool, amp_pct: float, min_pct: float) -> str: + return ( + f"突破越过幅度:{'通过' if amp_ok else '不通过'}" + f"({round(float(amp_pct), 4)}%,要求 > {min_pct}%)" + ) + + +def format_auto_confirm_line(confirm_ok: bool, cfm_close, edge_price, direction: str) -> str: + side = "箱外上方" if (direction or "").lower() == "long" else "箱外下方" + return ( + f"第二根确认:{'通过' if confirm_ok else '不通过'}" + f"(确认收盘 {cfm_close},须收于{side},关键位 {edge_price})" + ) + + +def key_monitor_rule_template_context( + *, + kline_timeframe: str, + key_breakout_amp_min_pct: float, + key_volume_ma_bars: int, + key_volume_ratio_min: float, + key_auto_min_planned_rr: float, + key_daily_volume_rank_max: int, + key_confirm_breakout_bar: int, + key_confirm_bar: int, + key_alert_max_times: int, + key_alert_interval_minutes: int, + key_stop_outside_breakout_pct: float, + key_trend_stop_outside_pct: float, + false_breakout_validity_hours: int, + trigger_entry_validity_hours: int | None = None, +) -> dict[str, Any]: + """关键位监控页规则说明表格(Jinja key_rule_ctx)。""" + from lib.key_monitor.false_breakout_key_monitor_lib import ( + FALSE_BREAKOUT_OFFSET_PCT, + FALSE_BREAKOUT_RR, + FALSE_BREAKOUT_SL_PCT, + ) + from lib.key_monitor.trigger_entry_key_monitor_lib import TRIGGER_ENTRY_VALIDITY_HOURS + + te_hours = ( + int(trigger_entry_validity_hours) + if trigger_entry_validity_hours is not None + else TRIGGER_ENTRY_VALIDITY_HOURS + ) + + return { + "tf": (kline_timeframe or "5m").strip(), + "amp_min_pct": key_breakout_amp_min_pct, + "vol_ma_bars": key_volume_ma_bars, + "vol_ratio_min": key_volume_ratio_min, + "min_rr": key_auto_min_planned_rr, + "vol_rank_max": key_daily_volume_rank_max, + "breakout_bar": key_confirm_breakout_bar, + "confirm_bar": key_confirm_bar, + "alert_max": key_alert_max_times, + "alert_interval_min": key_alert_interval_minutes, + "stop_outside_pct": key_stop_outside_breakout_pct, + "trend_stop_outside_pct": key_trend_stop_outside_pct, + "fb_offset_pct": FALSE_BREAKOUT_OFFSET_PCT, + "fb_sl_pct": FALSE_BREAKOUT_SL_PCT, + "fb_rr": FALSE_BREAKOUT_RR, + "fb_valid_hours": false_breakout_validity_hours, + "trigger_entry_validity_hours": te_hours, + } diff --git a/key_monitor_schema_lib.py b/lib/key_monitor/key_monitor_schema_lib.py similarity index 100% rename from key_monitor_schema_lib.py rename to lib/key_monitor/key_monitor_schema_lib.py diff --git a/key_sl_tp_lib.py b/lib/key_monitor/key_sl_tp_lib.py similarity index 100% rename from key_sl_tp_lib.py rename to lib/key_monitor/key_sl_tp_lib.py diff --git a/trigger_entry_key_monitor_lib.py b/lib/key_monitor/trigger_entry_key_monitor_lib.py similarity index 95% rename from trigger_entry_key_monitor_lib.py rename to lib/key_monitor/trigger_entry_key_monitor_lib.py index c98efa1..5ff31bf 100644 --- a/trigger_entry_key_monitor_lib.py +++ b/lib/key_monitor/trigger_entry_key_monitor_lib.py @@ -1,296 +1,296 @@ -"""回调/突破触价开仓关键位监控:程序盯价、触达计划入场后市价成交(四所共用逻辑)。""" -from __future__ import annotations - -from datetime import datetime -from typing import Any, Callable, Optional - -from false_breakout_key_monitor_lib import ( - _parse_created_at, - expires_at_text, - is_false_breakout_expired, -) -from strategy_trend_lib import trend_dca_level_reached - -# 回调触价(原「触价开仓」) -CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE = "回调触价开仓" -LEGACY_TRIGGER_ENTRY_MONITOR_TYPE = "触价开仓" - -# 突破触价:标记价穿越 E 后立即市价开仓 -BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE = "突破触价开仓" - -TRIGGER_ENTRY_MONITOR_TYPES = frozenset( - { - CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, - BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, - LEGACY_TRIGGER_ENTRY_MONITOR_TYPE, - } -) - -TRIGGER_ENTRY_VALIDITY_HOURS = 24 -TRIGGER_ENTRY_CLOSE_FILLED = "trigger_entry_filled" -TRIGGER_ENTRY_CLOSE_TP_INVALIDATE = "trigger_tp_invalidate" -TRIGGER_ENTRY_CLOSE_SL_INVALIDATE = "trigger_sl_invalidate" -TRIGGER_ENTRY_CLOSE_EXPIRED = "trigger_entry_expired" -TRIGGER_ENTRY_CLOSE_EXCHANGE_FAILED = "trigger_exchange_failed" - -KEY_ENTRY_REASON_CALLBACK = "关键位回调触价开仓" -KEY_ENTRY_REASON_BREAKOUT = "关键位突破触价开仓" -KEY_ENTRY_REASON_TRIGGER_LEGACY = "关键位触价开仓" - - -def normalize_trigger_entry_monitor_type(monitor_type: Optional[str]) -> str: - mt = (monitor_type or "").strip() - if mt == LEGACY_TRIGGER_ENTRY_MONITOR_TYPE: - return CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE - return mt - - -def is_trigger_entry_key_monitor_type(monitor_type: Optional[str]) -> bool: - return (monitor_type or "").strip() in TRIGGER_ENTRY_MONITOR_TYPES - - -def is_callback_trigger_entry_key_monitor_type(monitor_type: Optional[str]) -> bool: - mt = normalize_trigger_entry_monitor_type(monitor_type) - return mt == CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE - - -def is_breakout_trigger_entry_key_monitor_type(monitor_type: Optional[str]) -> bool: - return (monitor_type or "").strip() == BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE - - -def key_entry_reason_for_monitor_type(monitor_type: Optional[str]) -> str: - if is_breakout_trigger_entry_key_monitor_type(monitor_type): - return KEY_ENTRY_REASON_BREAKOUT - if is_trigger_entry_key_monitor_type(monitor_type): - return KEY_ENTRY_REASON_CALLBACK - return KEY_ENTRY_REASON_TRIGGER_LEGACY - - -def trigger_entry_reached(direction: str, mark_price: float, entry: float) -> bool: - """回调触价:多=价跌至 E;空=价涨至 E。""" - return trend_dca_level_reached(direction, mark_price, entry) - - -def breakout_trigger_entry_crossed( - direction: str, - prev_mark: Optional[float], - mark: float, - entry: float, -) -> bool: - """突破触价:多=向上穿越 E;空=向下穿越 E。""" - try: - m = float(mark) - e = float(entry) - pm = float(prev_mark) if prev_mark is not None else None - except (TypeError, ValueError): - return False - direction = (direction or "long").strip().lower() - if direction == "long": - if pm is None: - return m > e - return pm <= e and m > e - if pm is None: - return m < e - return pm >= e and m < e - - -def trigger_should_fire( - monitor_type: Optional[str], - direction: str, - mark: float, - entry: float, - prev_mark: Optional[float] = None, -) -> bool: - if is_breakout_trigger_entry_key_monitor_type(monitor_type): - return breakout_trigger_entry_crossed(direction, prev_mark, mark, entry) - return trigger_entry_reached(direction, mark, entry) - - -def trigger_entry_invalidate_by_tp(direction: str, mark_price: float, take_profit: float) -> bool: - """未开仓前标记价先触达止盈侧则失效。""" - try: - m = float(mark_price) - tp = float(take_profit) - except (TypeError, ValueError): - return False - d = (direction or "long").strip().lower() - if d == "short": - return m <= tp - return m >= tp - - -def trigger_entry_invalidate_by_sl(direction: str, mark_price: float, stop_loss: float) -> bool: - """突破触价:未到 E 先触达止损侧则失效。""" - try: - m = float(mark_price) - sl = float(stop_loss) - except (TypeError, ValueError): - return False - d = (direction or "long").strip().lower() - if d == "long": - return m <= sl - return m >= sl - - -def trigger_entry_invalidate( - monitor_type: Optional[str], - direction: str, - mark: float, - stop_loss: float, - take_profit: float, -) -> Optional[str]: - if trigger_entry_invalidate_by_tp(direction, mark, take_profit): - return "tp" - if is_breakout_trigger_entry_key_monitor_type(monitor_type): - if trigger_entry_invalidate_by_sl(direction, mark, stop_loss): - return "sl" - return None - - -def validate_trigger_entry_geometry( - direction: str, - entry: float, - stop_loss: float, - take_profit: float, - mark_at_add: Optional[float] = None, - *, - monitor_type: Optional[str] = None, -) -> Optional[str]: - """返回错误文案;合法则 None。""" - try: - e = float(entry) - sl = float(stop_loss) - tp = float(take_profit) - except (TypeError, ValueError): - return "入场价、止损、止盈须为有效数字" - if e <= 0 or sl <= 0 or tp <= 0: - return "入场价、止损、止盈须大于 0" - d = (direction or "long").strip().lower() - mt = normalize_trigger_entry_monitor_type(monitor_type) - label = "突破触价开仓" if mt == BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE else "回调触价开仓" - if d == "long": - if not (sl < e < tp): - return "做多:须满足 止损 < 入场价 < 止盈" - if mark_at_add is not None: - m = float(mark_at_add) - if m >= tp: - return f"做多:当前价已不低于止盈,无法添加{label}" - if mt == BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE and m >= e: - return "做多:当前价须低于入场价(等待向上突破)" - elif d == "short": - if not (tp < e < sl): - return "做空:须满足 止盈 < 入场价 < 止损" - if mark_at_add is not None: - m = float(mark_at_add) - if m <= tp: - return f"做空:当前价已不高于止盈,无法添加{label}" - if mt == BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE and m <= e: - return "做空:当前价须高于入场价(等待向下跌破)" - else: - return "方向须为 long 或 short" - return None - - -def validate_trigger_entry_rr( - direction: str, - entry: float, - stop_loss: float, - take_profit: float, - min_rr: float, - calc_rr_ratio: Callable[..., Optional[float]], -) -> Optional[str]: - rr = calc_rr_ratio(direction, entry, stop_loss, take_profit) - if rr is None or rr <= float(min_rr): - fmt = f"{rr:.4f}" if rr is not None else "无法计算" - return f"计划盈亏比 {fmt}:1 未达要求(>{float(min_rr)}:1)" - return None - - -def is_trigger_entry_expired( - created_at: Any, - now: datetime, - *, - hours: int = TRIGGER_ENTRY_VALIDITY_HOURS, -) -> bool: - return is_false_breakout_expired(created_at, now, hours=hours) - - -def trigger_entry_expires_at_text( - created_at: Any, - *, - hours: int = TRIGGER_ENTRY_VALIDITY_HOURS, -) -> str: - return expires_at_text(created_at, hours=hours) - - -def count_pending_trigger_entries(conn: Any, trading_day: str) -> int: - td = (trading_day or "").strip() - if not td: - return 0 - placeholders = ",".join("?" * len(TRIGGER_ENTRY_MONITOR_TYPES)) - row = conn.execute( - f"SELECT COUNT(*) FROM key_monitors WHERE monitor_type IN ({placeholders}) AND session_date=?", - (*TRIGGER_ENTRY_MONITOR_TYPES, td), - ).fetchone() - return int(row[0] if row else 0) - - -def check_trigger_entry_intent_limit( - conn: Any, - trading_day: str, - opens_today: int, - hard_limit: int, -) -> tuple[bool, str]: - """当日开仓意图:已成交次数 + 待触发触价条数。""" - if int(hard_limit) <= 0: - return True, "" - pending = count_pending_trigger_entries(conn, trading_day) - total = int(opens_today) + pending - if total >= int(hard_limit): - return ( - False, - f"本交易日开仓意图已达上限(已开 {int(opens_today)} + 待触发 {pending} / 硬上限 {int(hard_limit)})", - ) - return True, "" - - -def trigger_entry_gate_preview( - *, - monitor_type: Optional[str] = None, - entry_display: str, - take_profit_display: str, - created_at: Any = None, - now: Optional[datetime] = None, - expired: bool = False, - tp_invalidated: bool = False, - sl_invalidated: bool = False, - hours: int = TRIGGER_ENTRY_VALIDITY_HOURS, -) -> dict[str, Any]: - now_dt = now or datetime.now() - is_exp = expired or is_trigger_entry_expired(created_at, now_dt, hours=hours) - exp_txt = trigger_entry_expires_at_text(created_at, hours=hours) - mt = normalize_trigger_entry_monitor_type(monitor_type) - if tp_invalidated: - status = "止盈侧失效" - elif sl_invalidated: - status = "止损侧失效" - elif is_exp: - status = "已过期" - elif mt == BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE: - status = "突破待触发" - else: - status = "回调待触发" - mode = "突破" if mt == BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE else "回调" - metrics_parts: list[str] = [f"TP:{take_profit_display}"] - if exp_txt != "—": - metrics_parts.append(f"截至:{exp_txt}") - return { - "summary": f"{mode}触价 E={entry_display} {status}", - "metrics": " ".join(metrics_parts), - "gate_ok": not is_exp and not tp_invalidated and not sl_invalidated, - } - - -# 兼容旧 import -TRIGGER_ENTRY_MONITOR_TYPE = CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE -KEY_ENTRY_REASON_TRIGGER = KEY_ENTRY_REASON_CALLBACK +"""回调/突破触价开仓关键位监控:程序盯价、触达计划入场后市价成交(四所共用逻辑)。""" +from __future__ import annotations + +from datetime import datetime +from typing import Any, Callable, Optional + +from lib.key_monitor.false_breakout_key_monitor_lib import ( + _parse_created_at, + expires_at_text, + is_false_breakout_expired, +) +from lib.strategy.strategy_trend_lib import trend_dca_level_reached + +# 回调触价(原「触价开仓」) +CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE = "回调触价开仓" +LEGACY_TRIGGER_ENTRY_MONITOR_TYPE = "触价开仓" + +# 突破触价:标记价穿越 E 后立即市价开仓 +BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE = "突破触价开仓" + +TRIGGER_ENTRY_MONITOR_TYPES = frozenset( + { + CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, + BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, + LEGACY_TRIGGER_ENTRY_MONITOR_TYPE, + } +) + +TRIGGER_ENTRY_VALIDITY_HOURS = 24 +TRIGGER_ENTRY_CLOSE_FILLED = "trigger_entry_filled" +TRIGGER_ENTRY_CLOSE_TP_INVALIDATE = "trigger_tp_invalidate" +TRIGGER_ENTRY_CLOSE_SL_INVALIDATE = "trigger_sl_invalidate" +TRIGGER_ENTRY_CLOSE_EXPIRED = "trigger_entry_expired" +TRIGGER_ENTRY_CLOSE_EXCHANGE_FAILED = "trigger_exchange_failed" + +KEY_ENTRY_REASON_CALLBACK = "关键位回调触价开仓" +KEY_ENTRY_REASON_BREAKOUT = "关键位突破触价开仓" +KEY_ENTRY_REASON_TRIGGER_LEGACY = "关键位触价开仓" + + +def normalize_trigger_entry_monitor_type(monitor_type: Optional[str]) -> str: + mt = (monitor_type or "").strip() + if mt == LEGACY_TRIGGER_ENTRY_MONITOR_TYPE: + return CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE + return mt + + +def is_trigger_entry_key_monitor_type(monitor_type: Optional[str]) -> bool: + return (monitor_type or "").strip() in TRIGGER_ENTRY_MONITOR_TYPES + + +def is_callback_trigger_entry_key_monitor_type(monitor_type: Optional[str]) -> bool: + mt = normalize_trigger_entry_monitor_type(monitor_type) + return mt == CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE + + +def is_breakout_trigger_entry_key_monitor_type(monitor_type: Optional[str]) -> bool: + return (monitor_type or "").strip() == BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE + + +def key_entry_reason_for_monitor_type(monitor_type: Optional[str]) -> str: + if is_breakout_trigger_entry_key_monitor_type(monitor_type): + return KEY_ENTRY_REASON_BREAKOUT + if is_trigger_entry_key_monitor_type(monitor_type): + return KEY_ENTRY_REASON_CALLBACK + return KEY_ENTRY_REASON_TRIGGER_LEGACY + + +def trigger_entry_reached(direction: str, mark_price: float, entry: float) -> bool: + """回调触价:多=价跌至 E;空=价涨至 E。""" + return trend_dca_level_reached(direction, mark_price, entry) + + +def breakout_trigger_entry_crossed( + direction: str, + prev_mark: Optional[float], + mark: float, + entry: float, +) -> bool: + """突破触价:多=向上穿越 E;空=向下穿越 E。""" + try: + m = float(mark) + e = float(entry) + pm = float(prev_mark) if prev_mark is not None else None + except (TypeError, ValueError): + return False + direction = (direction or "long").strip().lower() + if direction == "long": + if pm is None: + return m > e + return pm <= e and m > e + if pm is None: + return m < e + return pm >= e and m < e + + +def trigger_should_fire( + monitor_type: Optional[str], + direction: str, + mark: float, + entry: float, + prev_mark: Optional[float] = None, +) -> bool: + if is_breakout_trigger_entry_key_monitor_type(monitor_type): + return breakout_trigger_entry_crossed(direction, prev_mark, mark, entry) + return trigger_entry_reached(direction, mark, entry) + + +def trigger_entry_invalidate_by_tp(direction: str, mark_price: float, take_profit: float) -> bool: + """未开仓前标记价先触达止盈侧则失效。""" + try: + m = float(mark_price) + tp = float(take_profit) + except (TypeError, ValueError): + return False + d = (direction or "long").strip().lower() + if d == "short": + return m <= tp + return m >= tp + + +def trigger_entry_invalidate_by_sl(direction: str, mark_price: float, stop_loss: float) -> bool: + """突破触价:未到 E 先触达止损侧则失效。""" + try: + m = float(mark_price) + sl = float(stop_loss) + except (TypeError, ValueError): + return False + d = (direction or "long").strip().lower() + if d == "long": + return m <= sl + return m >= sl + + +def trigger_entry_invalidate( + monitor_type: Optional[str], + direction: str, + mark: float, + stop_loss: float, + take_profit: float, +) -> Optional[str]: + if trigger_entry_invalidate_by_tp(direction, mark, take_profit): + return "tp" + if is_breakout_trigger_entry_key_monitor_type(monitor_type): + if trigger_entry_invalidate_by_sl(direction, mark, stop_loss): + return "sl" + return None + + +def validate_trigger_entry_geometry( + direction: str, + entry: float, + stop_loss: float, + take_profit: float, + mark_at_add: Optional[float] = None, + *, + monitor_type: Optional[str] = None, +) -> Optional[str]: + """返回错误文案;合法则 None。""" + try: + e = float(entry) + sl = float(stop_loss) + tp = float(take_profit) + except (TypeError, ValueError): + return "入场价、止损、止盈须为有效数字" + if e <= 0 or sl <= 0 or tp <= 0: + return "入场价、止损、止盈须大于 0" + d = (direction or "long").strip().lower() + mt = normalize_trigger_entry_monitor_type(monitor_type) + label = "突破触价开仓" if mt == BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE else "回调触价开仓" + if d == "long": + if not (sl < e < tp): + return "做多:须满足 止损 < 入场价 < 止盈" + if mark_at_add is not None: + m = float(mark_at_add) + if m >= tp: + return f"做多:当前价已不低于止盈,无法添加{label}" + if mt == BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE and m >= e: + return "做多:当前价须低于入场价(等待向上突破)" + elif d == "short": + if not (tp < e < sl): + return "做空:须满足 止盈 < 入场价 < 止损" + if mark_at_add is not None: + m = float(mark_at_add) + if m <= tp: + return f"做空:当前价已不高于止盈,无法添加{label}" + if mt == BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE and m <= e: + return "做空:当前价须高于入场价(等待向下跌破)" + else: + return "方向须为 long 或 short" + return None + + +def validate_trigger_entry_rr( + direction: str, + entry: float, + stop_loss: float, + take_profit: float, + min_rr: float, + calc_rr_ratio: Callable[..., Optional[float]], +) -> Optional[str]: + rr = calc_rr_ratio(direction, entry, stop_loss, take_profit) + if rr is None or rr <= float(min_rr): + fmt = f"{rr:.4f}" if rr is not None else "无法计算" + return f"计划盈亏比 {fmt}:1 未达要求(>{float(min_rr)}:1)" + return None + + +def is_trigger_entry_expired( + created_at: Any, + now: datetime, + *, + hours: int = TRIGGER_ENTRY_VALIDITY_HOURS, +) -> bool: + return is_false_breakout_expired(created_at, now, hours=hours) + + +def trigger_entry_expires_at_text( + created_at: Any, + *, + hours: int = TRIGGER_ENTRY_VALIDITY_HOURS, +) -> str: + return expires_at_text(created_at, hours=hours) + + +def count_pending_trigger_entries(conn: Any, trading_day: str) -> int: + td = (trading_day or "").strip() + if not td: + return 0 + placeholders = ",".join("?" * len(TRIGGER_ENTRY_MONITOR_TYPES)) + row = conn.execute( + f"SELECT COUNT(*) FROM key_monitors WHERE monitor_type IN ({placeholders}) AND session_date=?", + (*TRIGGER_ENTRY_MONITOR_TYPES, td), + ).fetchone() + return int(row[0] if row else 0) + + +def check_trigger_entry_intent_limit( + conn: Any, + trading_day: str, + opens_today: int, + hard_limit: int, +) -> tuple[bool, str]: + """当日开仓意图:已成交次数 + 待触发触价条数。""" + if int(hard_limit) <= 0: + return True, "" + pending = count_pending_trigger_entries(conn, trading_day) + total = int(opens_today) + pending + if total >= int(hard_limit): + return ( + False, + f"本交易日开仓意图已达上限(已开 {int(opens_today)} + 待触发 {pending} / 硬上限 {int(hard_limit)})", + ) + return True, "" + + +def trigger_entry_gate_preview( + *, + monitor_type: Optional[str] = None, + entry_display: str, + take_profit_display: str, + created_at: Any = None, + now: Optional[datetime] = None, + expired: bool = False, + tp_invalidated: bool = False, + sl_invalidated: bool = False, + hours: int = TRIGGER_ENTRY_VALIDITY_HOURS, +) -> dict[str, Any]: + now_dt = now or datetime.now() + is_exp = expired or is_trigger_entry_expired(created_at, now_dt, hours=hours) + exp_txt = trigger_entry_expires_at_text(created_at, hours=hours) + mt = normalize_trigger_entry_monitor_type(monitor_type) + if tp_invalidated: + status = "止盈侧失效" + elif sl_invalidated: + status = "止损侧失效" + elif is_exp: + status = "已过期" + elif mt == BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE: + status = "突破待触发" + else: + status = "回调待触发" + mode = "突破" if mt == BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE else "回调" + metrics_parts: list[str] = [f"TP:{take_profit_display}"] + if exp_txt != "—": + metrics_parts.append(f"截至:{exp_txt}") + return { + "summary": f"{mode}触价 E={entry_display} {status}", + "metrics": " ".join(metrics_parts), + "gate_ok": not is_exp and not tp_invalidated and not sl_invalidated, + } + + +# 兼容旧 import +TRIGGER_ENTRY_MONITOR_TYPE = CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE +KEY_ENTRY_REASON_TRIGGER = KEY_ENTRY_REASON_CALLBACK diff --git a/lib/paths.py b/lib/paths.py new file mode 100644 index 0000000..4493d57 --- /dev/null +++ b/lib/paths.py @@ -0,0 +1,22 @@ +"""Repository path helpers for lib/ assets.""" +from __future__ import annotations + +from pathlib import Path + +LIB_DIR = Path(__file__).resolve().parent +REPO_ROOT = LIB_DIR.parent + + +def strategy_templates_dir(repo_root: str | Path | None = None) -> str: + root = Path(repo_root) if repo_root is not None else REPO_ROOT + return str(root / "lib" / "strategy" / "templates") + + +def embed_templates_dir(repo_root: str | Path | None = None) -> str: + root = Path(repo_root) if repo_root is not None else REPO_ROOT + return str(root / "lib" / "instance" / "templates") + + +def common_static_dir(repo_root: str | Path | None = None) -> str: + root = Path(repo_root) if repo_root is not None else REPO_ROOT + return str(root / "lib" / "common" / "static") diff --git a/lib/strategy/__init__.py b/lib/strategy/__init__.py new file mode 100644 index 0000000..ab164b5 --- /dev/null +++ b/lib/strategy/__init__.py @@ -0,0 +1 @@ +"""Shared library package.""" diff --git a/strategy_config.py b/lib/strategy/strategy_config.py similarity index 100% rename from strategy_config.py rename to lib/strategy/strategy_config.py diff --git a/strategy_db.py b/lib/strategy/strategy_db.py similarity index 95% rename from strategy_db.py rename to lib/strategy/strategy_db.py index 19d3121..73a621c 100644 --- a/strategy_db.py +++ b/lib/strategy/strategy_db.py @@ -1,164 +1,164 @@ -"""策略交易相关表结构(各所 crypto.db 共用 schema)。""" - -ROLL_GROUPS_SQL = """ -CREATE TABLE IF NOT EXISTS roll_groups ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - order_monitor_id INTEGER, - symbol TEXT NOT NULL, - exchange_symbol TEXT, - direction TEXT NOT NULL, - initial_take_profit REAL, - initial_stop_loss REAL, - current_stop_loss REAL, - risk_percent REAL DEFAULT 2, - leg_count INTEGER DEFAULT 0, - status TEXT DEFAULT 'active', - created_at TEXT, - updated_at TEXT -) -""" - -ROLL_LEGS_SQL = """ -CREATE TABLE IF NOT EXISTS roll_legs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - roll_group_id INTEGER NOT NULL, - leg_index INTEGER NOT NULL, - add_mode TEXT NOT NULL, - fib_upper REAL, - fib_lower REAL, - limit_price REAL, - fill_price REAL, - amount REAL, - new_stop_loss REAL, - exchange_order_id TEXT, - status TEXT DEFAULT 'filled', - created_at TEXT, - FOREIGN KEY (roll_group_id) REFERENCES roll_groups(id) -) -""" - -TREND_PLANS_SQL = """ -CREATE TABLE IF NOT EXISTS trend_pullback_plans ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - status TEXT DEFAULT 'active', - symbol TEXT NOT NULL, - exchange_symbol TEXT, - direction TEXT NOT NULL DEFAULT 'long', - leverage INTEGER NOT NULL, - stop_loss REAL NOT NULL, - add_upper REAL NOT NULL, - take_profit REAL NOT NULL, - risk_percent REAL DEFAULT 5, - snapshot_available_usdt REAL, - snapshot_at TEXT, - plan_margin_capital REAL, - target_order_amount REAL, - first_order_amount REAL, - remainder_total REAL, - dca_legs INTEGER DEFAULT 5, - per_leg_amount REAL, - grid_prices_json TEXT, - leg_amounts_json TEXT, - legs_done INTEGER DEFAULT 0, - first_order_done INTEGER DEFAULT 0, - last_mark_price REAL, - avg_entry_price REAL, - order_amount_open REAL, - opened_at TEXT, - opened_at_ms INTEGER, - session_date TEXT, - message TEXT, - initial_stop_loss REAL, - breakeven_applied INTEGER DEFAULT 0, - breakeven_applied_at TEXT -) -""" - -TREND_PREVIEWS_SQL = """ -CREATE TABLE IF NOT EXISTS trend_pullback_previews ( - id TEXT PRIMARY KEY, - symbol TEXT NOT NULL, - exchange_symbol TEXT NOT NULL, - direction TEXT NOT NULL, - leverage INTEGER NOT NULL, - stop_loss REAL NOT NULL, - add_upper REAL NOT NULL, - take_profit REAL NOT NULL, - risk_percent REAL NOT NULL, - snapshot_available_usdt REAL NOT NULL, - snapshot_at TEXT, - live_price_ref REAL, - plan_margin_capital REAL, - target_order_amount REAL, - first_order_amount REAL, - remainder_total REAL, - dca_legs INTEGER, - per_leg_amount REAL, - grid_prices_json TEXT, - leg_amounts_json TEXT, - expires_at_ms INTEGER NOT NULL, - created_at TEXT -) -""" - -TREND_PREVIEW_SNAPSHOTS_SQL = """ -CREATE TABLE IF NOT EXISTS trend_pullback_preview_snapshots ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - preview_id TEXT NOT NULL UNIQUE, - symbol TEXT NOT NULL, - exchange_symbol TEXT NOT NULL, - direction TEXT NOT NULL, - leverage INTEGER NOT NULL, - stop_loss REAL NOT NULL, - add_upper REAL NOT NULL, - take_profit REAL NOT NULL, - risk_percent REAL NOT NULL, - snapshot_available_usdt REAL NOT NULL, - snapshot_at TEXT, - live_price_ref REAL, - plan_margin_capital REAL, - target_order_amount REAL, - first_order_amount REAL, - remainder_total REAL, - dca_legs INTEGER, - per_leg_amount REAL, - grid_prices_json TEXT, - leg_amounts_json TEXT, - expires_at_ms INTEGER NOT NULL, - preview_created_at TEXT, - outcome TEXT DEFAULT 'open', - executed_plan_id INTEGER -) -""" - - -def init_strategy_tables(conn) -> None: - from strategy_snapshot_lib import init_strategy_snapshot_table - - conn.execute(ROLL_GROUPS_SQL) - conn.execute(ROLL_LEGS_SQL) - conn.execute(TREND_PLANS_SQL) - conn.execute(TREND_PREVIEWS_SQL) - conn.execute(TREND_PREVIEW_SNAPSHOTS_SQL) - init_strategy_snapshot_table(conn) - for ddl in ( - "ALTER TABLE trend_pullback_plans ADD COLUMN leg_amounts_json TEXT", - "ALTER TABLE trend_pullback_plans ADD COLUMN initial_stop_loss REAL", - "ALTER TABLE trend_pullback_plans ADD COLUMN breakeven_applied INTEGER DEFAULT 0", - "ALTER TABLE trend_pullback_plans ADD COLUMN breakeven_applied_at TEXT", - "ALTER TABLE trend_pullback_preview_snapshots ADD COLUMN preview_created_at TEXT", - "ALTER TABLE trend_pullback_preview_snapshots ADD COLUMN outcome TEXT DEFAULT 'open'", - "ALTER TABLE trend_pullback_preview_snapshots ADD COLUMN executed_plan_id INTEGER", - "ALTER TABLE trade_records ADD COLUMN trend_plan_id INTEGER", - "ALTER TABLE order_monitors ADD COLUMN trend_plan_id INTEGER", - "ALTER TABLE order_monitors ADD COLUMN monitor_type TEXT", - "ALTER TABLE order_monitors ADD COLUMN key_signal_type TEXT", - "ALTER TABLE trend_pullback_plans ADD COLUMN leg_fill_prices_json TEXT", - "ALTER TABLE roll_legs ADD COLUMN stop_offset_pct REAL", - "ALTER TABLE roll_legs ADD COLUMN breakthrough_price REAL", - "ALTER TABLE roll_legs ADD COLUMN last_mark_price REAL", - ): - try: - conn.execute(ddl) - except Exception: - pass +"""策略交易相关表结构(各所 crypto.db 共用 schema)。""" + +ROLL_GROUPS_SQL = """ +CREATE TABLE IF NOT EXISTS roll_groups ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + order_monitor_id INTEGER, + symbol TEXT NOT NULL, + exchange_symbol TEXT, + direction TEXT NOT NULL, + initial_take_profit REAL, + initial_stop_loss REAL, + current_stop_loss REAL, + risk_percent REAL DEFAULT 2, + leg_count INTEGER DEFAULT 0, + status TEXT DEFAULT 'active', + created_at TEXT, + updated_at TEXT +) +""" + +ROLL_LEGS_SQL = """ +CREATE TABLE IF NOT EXISTS roll_legs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + roll_group_id INTEGER NOT NULL, + leg_index INTEGER NOT NULL, + add_mode TEXT NOT NULL, + fib_upper REAL, + fib_lower REAL, + limit_price REAL, + fill_price REAL, + amount REAL, + new_stop_loss REAL, + exchange_order_id TEXT, + status TEXT DEFAULT 'filled', + created_at TEXT, + FOREIGN KEY (roll_group_id) REFERENCES roll_groups(id) +) +""" + +TREND_PLANS_SQL = """ +CREATE TABLE IF NOT EXISTS trend_pullback_plans ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + status TEXT DEFAULT 'active', + symbol TEXT NOT NULL, + exchange_symbol TEXT, + direction TEXT NOT NULL DEFAULT 'long', + leverage INTEGER NOT NULL, + stop_loss REAL NOT NULL, + add_upper REAL NOT NULL, + take_profit REAL NOT NULL, + risk_percent REAL DEFAULT 5, + snapshot_available_usdt REAL, + snapshot_at TEXT, + plan_margin_capital REAL, + target_order_amount REAL, + first_order_amount REAL, + remainder_total REAL, + dca_legs INTEGER DEFAULT 5, + per_leg_amount REAL, + grid_prices_json TEXT, + leg_amounts_json TEXT, + legs_done INTEGER DEFAULT 0, + first_order_done INTEGER DEFAULT 0, + last_mark_price REAL, + avg_entry_price REAL, + order_amount_open REAL, + opened_at TEXT, + opened_at_ms INTEGER, + session_date TEXT, + message TEXT, + initial_stop_loss REAL, + breakeven_applied INTEGER DEFAULT 0, + breakeven_applied_at TEXT +) +""" + +TREND_PREVIEWS_SQL = """ +CREATE TABLE IF NOT EXISTS trend_pullback_previews ( + id TEXT PRIMARY KEY, + symbol TEXT NOT NULL, + exchange_symbol TEXT NOT NULL, + direction TEXT NOT NULL, + leverage INTEGER NOT NULL, + stop_loss REAL NOT NULL, + add_upper REAL NOT NULL, + take_profit REAL NOT NULL, + risk_percent REAL NOT NULL, + snapshot_available_usdt REAL NOT NULL, + snapshot_at TEXT, + live_price_ref REAL, + plan_margin_capital REAL, + target_order_amount REAL, + first_order_amount REAL, + remainder_total REAL, + dca_legs INTEGER, + per_leg_amount REAL, + grid_prices_json TEXT, + leg_amounts_json TEXT, + expires_at_ms INTEGER NOT NULL, + created_at TEXT +) +""" + +TREND_PREVIEW_SNAPSHOTS_SQL = """ +CREATE TABLE IF NOT EXISTS trend_pullback_preview_snapshots ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + preview_id TEXT NOT NULL UNIQUE, + symbol TEXT NOT NULL, + exchange_symbol TEXT NOT NULL, + direction TEXT NOT NULL, + leverage INTEGER NOT NULL, + stop_loss REAL NOT NULL, + add_upper REAL NOT NULL, + take_profit REAL NOT NULL, + risk_percent REAL NOT NULL, + snapshot_available_usdt REAL NOT NULL, + snapshot_at TEXT, + live_price_ref REAL, + plan_margin_capital REAL, + target_order_amount REAL, + first_order_amount REAL, + remainder_total REAL, + dca_legs INTEGER, + per_leg_amount REAL, + grid_prices_json TEXT, + leg_amounts_json TEXT, + expires_at_ms INTEGER NOT NULL, + preview_created_at TEXT, + outcome TEXT DEFAULT 'open', + executed_plan_id INTEGER +) +""" + + +def init_strategy_tables(conn) -> None: + from lib.strategy.strategy_snapshot_lib import init_strategy_snapshot_table + + conn.execute(ROLL_GROUPS_SQL) + conn.execute(ROLL_LEGS_SQL) + conn.execute(TREND_PLANS_SQL) + conn.execute(TREND_PREVIEWS_SQL) + conn.execute(TREND_PREVIEW_SNAPSHOTS_SQL) + init_strategy_snapshot_table(conn) + for ddl in ( + "ALTER TABLE trend_pullback_plans ADD COLUMN leg_amounts_json TEXT", + "ALTER TABLE trend_pullback_plans ADD COLUMN initial_stop_loss REAL", + "ALTER TABLE trend_pullback_plans ADD COLUMN breakeven_applied INTEGER DEFAULT 0", + "ALTER TABLE trend_pullback_plans ADD COLUMN breakeven_applied_at TEXT", + "ALTER TABLE trend_pullback_preview_snapshots ADD COLUMN preview_created_at TEXT", + "ALTER TABLE trend_pullback_preview_snapshots ADD COLUMN outcome TEXT DEFAULT 'open'", + "ALTER TABLE trend_pullback_preview_snapshots ADD COLUMN executed_plan_id INTEGER", + "ALTER TABLE trade_records ADD COLUMN trend_plan_id INTEGER", + "ALTER TABLE order_monitors ADD COLUMN trend_plan_id INTEGER", + "ALTER TABLE order_monitors ADD COLUMN monitor_type TEXT", + "ALTER TABLE order_monitors ADD COLUMN key_signal_type TEXT", + "ALTER TABLE trend_pullback_plans ADD COLUMN leg_fill_prices_json TEXT", + "ALTER TABLE roll_legs ADD COLUMN stop_offset_pct REAL", + "ALTER TABLE roll_legs ADD COLUMN breakthrough_price REAL", + "ALTER TABLE roll_legs ADD COLUMN last_mark_price REAL", + ): + try: + conn.execute(ddl) + except Exception: + pass diff --git a/strategy_exchange_base.py b/lib/strategy/strategy_exchange_base.py similarity index 100% rename from strategy_exchange_base.py rename to lib/strategy/strategy_exchange_base.py diff --git a/strategy_exchange_binance.py b/lib/strategy/strategy_exchange_binance.py similarity index 66% rename from strategy_exchange_binance.py rename to lib/strategy/strategy_exchange_binance.py index 1ff115f..7d21d01 100644 --- a/strategy_exchange_binance.py +++ b/lib/strategy/strategy_exchange_binance.py @@ -1,4 +1,4 @@ -"""Binance USDT-M 永续 — 策略交易交易所适配(见 strategy_config.build_strategy_config)。""" -from strategy_exchange_base import StrategyExchangeAdapter - -__all__ = ["StrategyExchangeAdapter"] +"""Binance USDT-M 永续 — 策略交易交易所适配(见 strategy_config.build_strategy_config)。""" +from lib.strategy.strategy_exchange_base import StrategyExchangeAdapter + +__all__ = ["StrategyExchangeAdapter"] diff --git a/strategy_exchange_gate.py b/lib/strategy/strategy_exchange_gate.py similarity index 80% rename from strategy_exchange_gate.py rename to lib/strategy/strategy_exchange_gate.py index 8fddaf0..3ac73f4 100644 --- a/strategy_exchange_gate.py +++ b/lib/strategy/strategy_exchange_gate.py @@ -1,9 +1,9 @@ -""" -Gate.io USDT 永续 — 策略交易交易所侧能力。 - -实现方式:各 Gate 实例 app 通过 strategy_config.build_strategy_config(app_module) 注入 -ccxt 下单、精度、换 TP/SL;本文件为文档与类型锚点,避免在四个 app 重复实现滚仓公式。 -""" -from strategy_exchange_base import StrategyExchangeAdapter - -__all__ = ["StrategyExchangeAdapter"] +""" +Gate.io USDT 永续 — 策略交易交易所侧能力。 + +实现方式:各 Gate 实例 app 通过 strategy_config.build_strategy_config(app_module) 注入 +ccxt 下单、精度、换 TP/SL;本文件为文档与类型锚点,避免在四个 app 重复实现滚仓公式。 +""" +from lib.strategy.strategy_exchange_base import StrategyExchangeAdapter + +__all__ = ["StrategyExchangeAdapter"] diff --git a/strategy_exchange_okx.py b/lib/strategy/strategy_exchange_okx.py similarity index 64% rename from strategy_exchange_okx.py rename to lib/strategy/strategy_exchange_okx.py index de54fc0..cdc6930 100644 --- a/strategy_exchange_okx.py +++ b/lib/strategy/strategy_exchange_okx.py @@ -1,4 +1,4 @@ -"""OKX 永续 — 策略交易交易所适配(见 strategy_config.build_strategy_config)。""" -from strategy_exchange_base import StrategyExchangeAdapter - -__all__ = ["StrategyExchangeAdapter"] +"""OKX 永续 — 策略交易交易所适配(见 strategy_config.build_strategy_config)。""" +from lib.strategy.strategy_exchange_base import StrategyExchangeAdapter + +__all__ = ["StrategyExchangeAdapter"] diff --git a/strategy_records_register.py b/lib/strategy/strategy_records_register.py similarity index 94% rename from strategy_records_register.py rename to lib/strategy/strategy_records_register.py index 5c1a604..1b1645d 100644 --- a/strategy_records_register.py +++ b/lib/strategy/strategy_records_register.py @@ -1,72 +1,72 @@ -"""策略交易记录页:已结束趋势 / 顺势加仓快照(四所统一)。""" -from __future__ import annotations - -import json -from typing import Any - -from flask import flash, redirect, url_for - -from strategy_snapshot_lib import ( - STRATEGY_SNAPSHOTS_MAX_ROWS, - dedupe_strategy_snapshots, - list_strategy_snapshots_split, -) - - -def load_strategy_records_page( - conn, *, limit: int = STRATEGY_SNAPSHOTS_MAX_ROWS -) -> dict[str, Any]: - try: - if dedupe_strategy_snapshots(conn): - conn.commit() - except Exception: - pass - trend, roll, symbols = list_strategy_snapshots_split(conn, limit=limit) - return { - "strategy_trend_records": trend, - "strategy_roll_records": roll, - "strategy_record_symbols": symbols, - "strategy_records_limit": limit, - "strategy_snapshots": trend + roll, - } - - -def register_strategy_records(app, cfg: dict[str, Any]) -> None: - login_required = cfg["login_required"] - get_db = cfg["get_db"] - - def _lr(f): - return login_required(f) - - @_lr - @app.route("/strategy/records") - def strategy_records_page(): - m = cfg.get("app_module") - fn = getattr(m, "render_main_page", None) - if not callable(fn): - flash("render_main_page 未配置") - return redirect(url_for("strategy_trading_page")) - return fn("strategy_records") - - @_lr - @app.route("/strategy/records/") - def strategy_records_detail(snap_id: int): - conn = get_db() - row = conn.execute( - "SELECT * FROM strategy_trade_snapshots WHERE id=?", - (int(snap_id),), - ).fetchone() - conn.close() - if not row: - flash("未找到该策略快照") - return redirect(url_for("strategy_records_page")) - try: - snap = json.loads(row["snapshot_json"] or "{}") - except Exception: - snap = {} - dca = snap.get("dca_levels") or [] - flash( - f"快照 #{snap_id} {row['strategy_type']} {row['symbol']} " - f"{row['result_label']} · 补仓档 {len(dca)} 项(详情见列表页)" - ) - return redirect(url_for("strategy_records_page")) +"""策略交易记录页:已结束趋势 / 顺势加仓快照(四所统一)。""" +from __future__ import annotations + +import json +from typing import Any + +from flask import flash, redirect, url_for + +from lib.strategy.strategy_snapshot_lib import ( + STRATEGY_SNAPSHOTS_MAX_ROWS, + dedupe_strategy_snapshots, + list_strategy_snapshots_split, +) + + +def load_strategy_records_page( + conn, *, limit: int = STRATEGY_SNAPSHOTS_MAX_ROWS +) -> dict[str, Any]: + try: + if dedupe_strategy_snapshots(conn): + conn.commit() + except Exception: + pass + trend, roll, symbols = list_strategy_snapshots_split(conn, limit=limit) + return { + "strategy_trend_records": trend, + "strategy_roll_records": roll, + "strategy_record_symbols": symbols, + "strategy_records_limit": limit, + "strategy_snapshots": trend + roll, + } + + +def register_strategy_records(app, cfg: dict[str, Any]) -> None: + login_required = cfg["login_required"] + get_db = cfg["get_db"] + + def _lr(f): + return login_required(f) + + @_lr + @app.route("/strategy/records") + def strategy_records_page(): + m = cfg.get("app_module") + fn = getattr(m, "render_main_page", None) + if not callable(fn): + flash("render_main_page 未配置") + return redirect(url_for("strategy_trading_page")) + return fn("strategy_records") + + @_lr + @app.route("/strategy/records/") + def strategy_records_detail(snap_id: int): + conn = get_db() + row = conn.execute( + "SELECT * FROM strategy_trade_snapshots WHERE id=?", + (int(snap_id),), + ).fetchone() + conn.close() + if not row: + flash("未找到该策略快照") + return redirect(url_for("strategy_records_page")) + try: + snap = json.loads(row["snapshot_json"] or "{}") + except Exception: + snap = {} + dca = snap.get("dca_levels") or [] + flash( + f"快照 #{snap_id} {row['strategy_type']} {row['symbol']} " + f"{row['result_label']} · 补仓档 {len(dca)} 项(详情见列表页)" + ) + return redirect(url_for("strategy_records_page")) diff --git a/strategy_register.py b/lib/strategy/strategy_register.py similarity index 94% rename from strategy_register.py rename to lib/strategy/strategy_register.py index cf7f40c..d558362 100644 --- a/strategy_register.py +++ b/lib/strategy/strategy_register.py @@ -1,619 +1,621 @@ -"""策略交易:Flask 路由注册(顺势加仓 + 趋势回调页)。逻辑在 strategy_*_lib。""" -from __future__ import annotations - -import html as html_module -import os -import re -from typing import Any, Optional - -from flask import Flask, flash, jsonify, redirect, render_template, request, url_for -from jinja2 import ChoiceLoader, FileSystemLoader - -from strategy_db import init_strategy_tables -from strategy_roll_lib import BREAKOUT_MODE, FIB_MODES, MARKET_MODE, preview_roll -from strategy_roll_monitor_lib import ( - cancel_roll_pending_leg, - count_filled_roll_legs, - count_pending_roll_legs, - sync_roll_after_external_close, -) - - -def _dedupe_strategy_snapshots_on_startup(cfg: dict[str, Any]) -> None: - """启动时清理历史重复快照(同计划同结果仅保留最新一条)。""" - get_db = cfg.get("get_db") - if not callable(get_db): - return - try: - from strategy_snapshot_lib import dedupe_strategy_snapshots - - conn = get_db() - try: - removed = dedupe_strategy_snapshots(conn) - if removed: - conn.commit() - print( - f"[strategy] deduped {removed} duplicate strategy_trade_snapshots", - flush=True, - ) - finally: - conn.close() - except Exception as e: - print(f"[strategy] snapshot dedupe skipped: {e}", flush=True) - - -def install_strategy_trading(app: Flask, repo_root: str, app_module: Any = None, **build_kw) -> None: - """在 app.py 末尾调用(login_required 已定义后)。仅注册 POST API;页面由各 app 的 render_main_page 渲染。""" - from strategy_config import build_strategy_config - - build_kw.pop("render_trend_page", None) - attach_strategy_templates(app, repo_root) - cfg = build_strategy_config(app_module, **build_kw) - register_strategy_trading(app, cfg) - from strategy_records_register import register_strategy_records - - register_strategy_records(app, cfg) - app.extensions["strategy_roll_cfg"] = cfg - _dedupe_strategy_snapshots_on_startup(cfg) - - -def attach_strategy_templates(app: Flask, repo_root: str) -> None: - strat_dir = os.path.join(repo_root, "strategy_templates") - if not os.path.isdir(strat_dir): - return - existing = app.jinja_loader - loaders = [FileSystemLoader(strat_dir)] - if existing is not None: - if isinstance(existing, ChoiceLoader): - loaders = list(existing.loaders) + loaders - else: - loaders.insert(0, existing) - app.jinja_loader = ChoiceLoader(loaders) - - -def register_strategy_trading(app: Flask, cfg: dict[str, Any]) -> None: - """cfg 由各市面 app 注入回调(仅 API / DB 差异)。""" - - login_required = cfg["login_required"] - - def _lr(f): - return login_required(f) - - @_lr - @app.route("/strategy/roll/preview", methods=["POST"]) - def strategy_roll_preview(): - data = request.get_json(silent=True) or request.form - err = _roll_preview_response(cfg, data, json_mode=request.is_json) - if request.is_json: - return jsonify(err) - if err.get("ok"): - p = err["preview"] - flash( - f"预览:约 {p.get('add_amount_display', '-')} 张," - f"合并均价 {p.get('avg_entry_after', '-')}," - f"打到止损约 {p.get('loss_at_sl_usdt', '-')}U" - ) - else: - flash(err.get("msg") or "预览失败") - return redirect(url_for("strategy_trading_page")) - - @_lr - @app.route("/strategy/roll/execute", methods=["POST"]) - def strategy_roll_execute(): - data = request.form - try: - ok, msg = _roll_execute(cfg, data) - except Exception as e: - fe = cfg.get("friendly_error") - msg = fe(e) if callable(fe) else str(e) - ok = False - flash(msg) - return redirect(url_for("strategy_trading_page")) - - @_lr - @app.route("/strategy/roll/cancel/", methods=["POST"]) - def strategy_roll_cancel_leg(leg_id: int): - conn = cfg["get_db"]() - try: - init_strategy_tables(conn) - ok, msg = cancel_roll_pending_leg(cfg, conn, leg_id) - finally: - conn.close() - if request.is_json: - return jsonify({"ok": ok, "msg": msg}) - flash(msg) - return redirect(url_for("strategy_trading_page")) - - @_lr - @app.route("/strategy/roll/docs") - def strategy_roll_docs(): - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "顺势加仓滚仓说明.md") - if not os.path.isfile(path): - flash("滚仓说明文档不存在") - return redirect(url_for("strategy_trading_page")) - with open(path, encoding="utf-8") as f: - raw = f.read() - return render_template( - "strategy_roll_docs.html", - doc_html=_roll_doc_markdown_to_html(raw), - exchange_display=cfg.get("exchange_display") or "", - ) - - -def _roll_doc_markdown_to_html(text: str) -> str: - """轻量 Markdown → HTML(仅供滚仓说明页)。""" - lines = text.splitlines() - out: list[str] = [] - i = 0 - in_code = False - code_buf: list[str] = [] - - def flush_code() -> None: - nonlocal code_buf - if code_buf: - out.append( - "
"
-                + html_module.escape("\n".join(code_buf))
-                + "
" - ) - code_buf = [] - - def inline_md(s: str) -> str: - s = html_module.escape(s) - s = re.sub(r"`([^`]+)`", r"\1", s) - s = re.sub(r"\*\*([^*]+)\*\*", r"\1", s) - return s - - while i < len(lines): - line = lines[i] - if line.strip().startswith("```"): - if in_code: - in_code = False - flush_code() - else: - in_code = True - i += 1 - continue - if in_code: - code_buf.append(line) - i += 1 - continue - if line.startswith("# "): - out.append(f"

{inline_md(line[2:].strip())}

") - elif line.startswith("## "): - out.append(f"

{inline_md(line[3:].strip())}

") - elif line.startswith("### "): - out.append(f"

{inline_md(line[4:].strip())}

") - elif line.strip() == "---": - out.append("
") - elif line.startswith("|") and "|" in line[1:]: - rows: list[str] = [] - while i < len(lines) and lines[i].startswith("|"): - rows.append(lines[i]) - i += 1 - if len(rows) >= 2 and re.match(r"^\|[\s\-:|]+\|$", rows[1].strip()): - out.append("") - hdr = [c.strip() for c in rows[0].strip("|").split("|")] - out.append("" + "".join(f"" for c in hdr) + "") - for row in rows[2:]: - cells = [c.strip() for c in row.strip("|").split("|")] - out.append("" + "".join(f"" for c in cells) + "") - out.append("
{inline_md(c)}
{inline_md(c)}
") - continue - elif re.match(r"^[-*]\s+", line): - out.append("
    ") - while i < len(lines) and re.match(r"^[-*]\s+", lines[i]): - item = re.sub(r"^[-*]\s+", "", lines[i]) - out.append(f"
  • {inline_md(item)}
  • ") - i += 1 - out.append("
") - continue - elif line.strip(): - out.append(f"

{inline_md(line.strip())}

") - i += 1 - flush_code() - return "\n".join(out) - - -def _row_to_dict(row) -> dict: - if row is None: - return {} - try: - return dict(row) - except Exception: - return {} - - -def _count_active_trends(conn, cfg: dict) -> int: - fn = cfg.get("count_active_trend_plans") - if callable(fn): - return int(fn(conn) or 0) - try: - return int( - conn.execute( - "SELECT COUNT(*) FROM trend_pullback_plans WHERE status='active'" - ).fetchone()[0] - ) - except Exception: - return 0 - - -def _risk_from_monitor(mon: dict, cfg: dict) -> tuple[Optional[float], Optional[str]]: - try: - rp = float(mon.get("risk_percent") or cfg.get("default_risk_percent", 2)) - except (TypeError, ValueError): - return None, "监控单风险%无效" - if rp <= 0: - return None, "监控单风险%须大于0" - return rp, None - - -def _contract_size(cfg: dict, ex_sym: str) -> float: - get_cs = cfg.get("get_contract_size") - if callable(get_cs): - try: - return float(get_cs(ex_sym) or 1.0) - except Exception: - pass - return 1.0 - - -def _roll_context(cfg: dict, data: dict) -> tuple[Optional[dict], Optional[str]]: - m = cfg.get("app_module") - if m is not None: - try: - from position_sizing_lib import OPEN_SOURCE_ROLL, assert_open_source_allowed - - mode = getattr(m, "POSITION_SIZING_MODE", None) or "risk" - ok_src, src_msg = assert_open_source_allowed(mode, OPEN_SOURCE_ROLL) - if not ok_src: - return None, src_msg - except Exception: - pass - get_db = cfg["get_db"] - symbol = cfg["normalize_symbol_input"](data.get("symbol") or "") - if not symbol: - return None, "请选择或填写币种" - direction = (data.get("direction") or "long").strip().lower() - ex_sym = cfg["normalize_exchange_symbol"](symbol) - conn = get_db() - init_strategy_tables(conn) - if _count_active_trends(conn, cfg) > 0: - conn.close() - return None, "存在运行中的趋势回调计划,请先结束后再滚仓" - mon = _get_active_monitor(conn, cfg, symbol, direction) - if not mon: - conn.close() - return None, "未找到该币种同向的下单监控持仓,请先在「实盘下单」开仓" - rg, legs_done, pending, roll_is_new = _get_or_create_roll_group_meta(conn, mon) - if pending > 0: - conn.close() - return None, "已有监控中的滚仓腿,请等待成交/失效或先删除后再提交" - conn_cap = get_db() - try: - capital = float(cfg["get_trading_capital_usdt"](conn_cap)) - finally: - conn_cap.close() - risk_pct, risk_err = _risk_from_monitor(mon, cfg) - if risk_err: - conn.close() - return None, risk_err - pos = cfg["get_position"](ex_sym, direction) - qty = float(pos.get("contracts") or 0) - if qty <= 0: - conn.close() - return None, "交易所无该方向持仓,无法滚仓" - entry = float(pos.get("entry_price") or mon.get("trigger_price") or 0) - if entry <= 0: - conn.close() - return None, "无法获取持仓均价" - mark_fn = cfg.get("get_mark_price") or cfg.get("get_price") - mark = mark_fn(symbol) if callable(mark_fn) else cfg["get_price"](symbol) - ctx = { - "conn": conn, - "mon": mon, - "rg": rg, - "legs_done": legs_done, - "symbol": symbol, - "direction": direction, - "ex_sym": ex_sym, - "qty": qty, - "entry": entry, - "mark": float(mark) if mark else None, - "capital": capital, - "risk_pct": float(risk_pct), - "tp0": float(mon.get("take_profit") or rg.get("initial_take_profit") or 0), - "contract_size": _contract_size(cfg, ex_sym), - } - return ctx, None - - -def _parse_roll_form(data: dict, ctx: dict) -> tuple[Optional[dict], Optional[str]]: - add_mode = (data.get("add_mode") or MARKET_MODE).strip().lower() - raw_sl = data.get("new_stop_loss") or data.get("sl") - if raw_sl in (None, ""): - return None, "请填写新止损价" - try: - new_sl = float(raw_sl) - except (TypeError, ValueError): - return None, "止损价格式错误" - if new_sl <= 0: - return None, "止损价须大于0" - fib_u = fib_l = bp = None - try: - if data.get("fib_upper") not in (None, ""): - fib_u = float(data.get("fib_upper")) - if data.get("fib_lower") not in (None, ""): - fib_l = float(data.get("fib_lower")) - if data.get("breakthrough_price") not in (None, ""): - bp = float(data.get("breakthrough_price")) - except (TypeError, ValueError): - return None, "价格参数格式错误" - - add_price = ctx.get("mark") - if add_mode == MARKET_MODE: - if add_price is None or add_price <= 0: - return None, "无法获取市价快照" - elif add_mode in FIB_MODES: - if fib_u is None or fib_l is None: - return None, "斐波须填写上沿 H 与下沿 L" - elif add_mode == BREAKOUT_MODE: - if bp is None: - return None, "突破加仓须填写突破价" - add_price = ctx.get("mark") - else: - return None, "加仓方式无效" - - return { - "add_mode": add_mode, - "new_stop_loss": new_sl, - "fib_upper": fib_u, - "fib_lower": fib_l, - "breakthrough_price": bp, - "add_price": add_price, - }, None - - -def _roll_preview_response(cfg: dict, data: dict, json_mode: bool = False) -> dict: - ctx, err = _roll_context(cfg, data) - if err: - return {"ok": False, "msg": err} - parsed, perr = _parse_roll_form(data, ctx) - if perr: - ctx["conn"].close() - return {"ok": False, "msg": perr} - conn = ctx["conn"] - try: - preview, perr2 = preview_roll( - direction=ctx["direction"], - symbol=ctx["symbol"], - qty_existing=ctx["qty"], - entry_existing=ctx["entry"], - initial_take_profit=ctx["tp0"], - add_mode=parsed["add_mode"], - new_stop_loss=parsed["new_stop_loss"], - risk_percent=ctx["risk_pct"], - capital_base_usdt=ctx["capital"], - add_price=parsed["add_price"], - fib_upper=parsed["fib_upper"], - fib_lower=parsed["fib_lower"], - breakthrough_price=parsed["breakthrough_price"], - legs_done=ctx["legs_done"], - contract_size=ctx["contract_size"], - ) - finally: - conn.close() - if perr2: - return {"ok": False, "msg": perr2} - amt_raw = float(preview["add_amount_raw"]) - amt_p = cfg["amount_to_precision"](ctx["ex_sym"], amt_raw) - preview["add_amount_display"] = amt_p if amt_p is not None else amt_raw - preview["risk_display"] = f"{ctx['risk_pct']:g}%≈{ctx['capital'] * ctx['risk_pct'] / 100:.2f}U" - price_fmt = cfg.get("price_fmt") - if callable(price_fmt): - preview["add_price_display"] = price_fmt(ctx["symbol"], preview["add_price"]) - preview["new_sl_display"] = price_fmt(ctx["symbol"], preview["new_stop_loss"]) - preview["tp_display"] = price_fmt(ctx["symbol"], preview["initial_take_profit"]) - return {"ok": True, "preview": preview} - - -def _roll_execute(cfg: dict, data: dict) -> tuple[bool, str]: - get_db = cfg["get_db"] - conn = None - try: - ok_live, reason = cfg["ensure_live_ready"]() - if not ok_live: - return False, reason or "实盘未就绪" - prev = _roll_preview_response(cfg, data) - if not prev.get("ok"): - return False, prev.get("msg") or "预览失败" - preview = prev["preview"] - symbol = cfg["normalize_symbol_input"](data.get("symbol") or "") - direction = preview["direction"] - ex_sym = cfg["normalize_exchange_symbol"](symbol) - add_mode = preview["add_mode"] - new_sl = float(preview["new_stop_loss"]) - tp0 = float(preview["initial_take_profit"]) - lev_fn = cfg.get("default_leverage") - if not callable(lev_fn): - lev_fn = lambda _s: 5 - leverage = int(data.get("leverage") or 0) or int(lev_fn(symbol)) - conn = get_db() - init_strategy_tables(conn) - mon = _get_active_monitor(conn, cfg, symbol, direction) - if not mon: - return False, "监控单已不存在" - rg, legs_done, pending, roll_is_new = _get_or_create_roll_group_meta(conn, mon) - if pending > 0: - return False, "已有监控中的滚仓腿,请先删除或等待结束" - if add_mode == MARKET_MODE: - amount = cfg["amount_to_precision"](ex_sym, float(preview["add_amount_raw"])) - if amount is None or amount <= 0: - return False, "加仓张数低于交易所最小精度" - order = cfg["market_add"](ex_sym, direction, amount, leverage) - fill = float( - cfg.get("resolve_fill_price", lambda o, s, p: p)( - order, ex_sym, preview["add_price"] - ) - or preview["add_price"] - ) - oid = str(order.get("id") or "") if isinstance(order, dict) else "" - cfg["replace_tpsl"](ex_sym, direction, new_sl, tp0, mon) - conn.execute( - """INSERT INTO roll_legs ( - roll_group_id, leg_index, add_mode, fib_upper, fib_lower, limit_price, - breakthrough_price, fill_price, amount, new_stop_loss, exchange_order_id, - status, created_at - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - rg["id"], - legs_done + 1, - preview["add_mode_label"], - preview.get("fib_upper"), - preview.get("fib_lower"), - None, - preview.get("breakthrough_price"), - fill, - amount, - new_sl, - oid, - "filled", - cfg["app_now_str"](), - ), - ) - conn.execute( - "UPDATE roll_groups SET leg_count=?, current_stop_loss=?, updated_at=? WHERE id=?", - (legs_done + 1, new_sl, cfg["app_now_str"](), rg["id"]), - ) - conn.execute( - "UPDATE order_monitors SET stop_loss=? WHERE id=?", - (new_sl, mon["id"]), - ) - conn.commit() - _maybe_notify_roll_started(cfg, rg, mon, symbol, direction, tp0, new_sl, roll_is_new=roll_is_new) - return True, f"市价加仓第 {legs_done + 1} 腿已成交,止损已更新,止盈仍为首仓" - # 程序监控:斐波 / 突破 - limit_px = None - if add_mode in FIB_MODES: - px_fn = cfg.get("price_to_precision") - limit_px = float(preview["add_price"]) - if callable(px_fn): - limit_px = float(px_fn(ex_sym, limit_px) or limit_px) - mark_fn = cfg.get("get_mark_price") or cfg.get("get_price") - last_mark = mark_fn(symbol) if callable(mark_fn) else preview["add_price"] - conn.execute( - """INSERT INTO roll_legs ( - roll_group_id, leg_index, add_mode, fib_upper, fib_lower, limit_price, - breakthrough_price, new_stop_loss, last_mark_price, status, created_at - ) VALUES (?,?,?,?,?,?,?,?,?,?,?)""", - ( - rg["id"], - legs_done + 1, - preview["add_mode_label"], - preview.get("fib_upper"), - preview.get("fib_lower"), - limit_px, - preview.get("breakthrough_price"), - new_sl, - last_mark, - "pending", - cfg["app_now_str"](), - ), - ) - conn.commit() - _maybe_notify_roll_started(cfg, rg, mon, symbol, direction, tp0, new_sl, roll_is_new=roll_is_new) - return True, f"已提交{preview['add_mode_label']}监控,触价后将市价加仓并更新止损" - except Exception as e: - fe = cfg.get("friendly_error") - return False, fe(e) if callable(fe) else str(e) - finally: - if conn is not None: - try: - conn.close() - except Exception: - pass - - -def _maybe_notify_roll_started(cfg, rg, mon, symbol, direction, tp0, new_sl, *, roll_is_new: bool) -> None: - if not roll_is_new: - return - try: - from strategy_wechat_notify import notify_roll_group_started - - notify_roll_group_started( - cfg, - group_id=int(rg["id"]), - symbol=symbol, - direction=direction, - order_monitor_id=int(mon["id"]), - initial_take_profit=tp0, - initial_stop_loss=float(mon.get("stop_loss") or new_sl), - ) - except Exception: - pass - - -def _get_active_monitor(conn, cfg: dict, symbol: str, direction: str) -> Optional[dict]: - row = conn.execute( - "SELECT * FROM order_monitors WHERE status='active' AND symbol=? AND direction=? ORDER BY id DESC LIMIT 1", - (symbol, direction), - ).fetchone() - return _row_to_dict(row) if row else None - - -def _get_or_create_roll_group_meta(conn, mon: dict) -> tuple[dict, int, int, bool]: - """返回 (roll_group, filled_legs, pending_legs, is_new_group)。""" - row = conn.execute( - "SELECT * FROM roll_groups WHERE order_monitor_id=? AND status='active' ORDER BY id DESC LIMIT 1", - (mon["id"],), - ).fetchone() - if row: - d = _row_to_dict(row) - gid = int(d["id"]) - filled = count_filled_roll_legs(conn, gid) - pending = count_pending_roll_legs(conn, gid) - return d, filled, pending, False - now = mon.get("created_at") or "" - cur = conn.execute( - """INSERT INTO roll_groups ( - order_monitor_id, symbol, exchange_symbol, direction, - initial_take_profit, initial_stop_loss, current_stop_loss, - risk_percent, leg_count, status, created_at, updated_at - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - mon["id"], - mon["symbol"], - mon.get("exchange_symbol"), - mon["direction"], - mon.get("take_profit"), - mon.get("stop_loss"), - mon.get("stop_loss"), - mon.get("risk_percent") or 2, - 0, - "active", - now, - now, - ), - ) - gid = int(cur.lastrowid) - return ( - { - "id": gid, - "leg_count": 0, - "initial_take_profit": mon.get("take_profit"), - "initial_stop_loss": mon.get("stop_loss"), - "symbol": mon.get("symbol"), - "direction": mon.get("direction"), - }, - 0, - 0, - True, - ) - - -def roll_sync_after_external_close(cfg: dict, conn, symbol: str, direction: str) -> dict: - """供 hub / del_order 调用的滚仓同步入口。""" - return sync_roll_after_external_close( - cfg, conn, symbol, direction, reason="手动平仓,滚仓监控已结束" - ) - +"""策略交易:Flask 路由注册(顺势加仓 + 趋势回调页)。逻辑在 strategy_*_lib。""" +from __future__ import annotations + +from lib.paths import strategy_templates_dir + +import html as html_module +import os +import re +from typing import Any, Optional + +from flask import Flask, flash, jsonify, redirect, render_template, request, url_for +from jinja2 import ChoiceLoader, FileSystemLoader + +from lib.strategy.strategy_db import init_strategy_tables +from lib.strategy.strategy_roll_lib import BREAKOUT_MODE, FIB_MODES, MARKET_MODE, preview_roll +from lib.strategy.strategy_roll_monitor_lib import ( + cancel_roll_pending_leg, + count_filled_roll_legs, + count_pending_roll_legs, + sync_roll_after_external_close, +) + + +def _dedupe_strategy_snapshots_on_startup(cfg: dict[str, Any]) -> None: + """启动时清理历史重复快照(同计划同结果仅保留最新一条)。""" + get_db = cfg.get("get_db") + if not callable(get_db): + return + try: + from lib.strategy.strategy_snapshot_lib import dedupe_strategy_snapshots + + conn = get_db() + try: + removed = dedupe_strategy_snapshots(conn) + if removed: + conn.commit() + print( + f"[strategy] deduped {removed} duplicate strategy_trade_snapshots", + flush=True, + ) + finally: + conn.close() + except Exception as e: + print(f"[strategy] snapshot dedupe skipped: {e}", flush=True) + + +def install_strategy_trading(app: Flask, repo_root: str, app_module: Any = None, **build_kw) -> None: + """在 app.py 末尾调用(login_required 已定义后)。仅注册 POST API;页面由各 app 的 render_main_page 渲染。""" + from lib.strategy.strategy_config import build_strategy_config + + build_kw.pop("render_trend_page", None) + attach_strategy_templates(app, repo_root) + cfg = build_strategy_config(app_module, **build_kw) + register_strategy_trading(app, cfg) + from lib.strategy.strategy_records_register import register_strategy_records + + register_strategy_records(app, cfg) + app.extensions["strategy_roll_cfg"] = cfg + _dedupe_strategy_snapshots_on_startup(cfg) + + +def attach_strategy_templates(app: Flask, repo_root: str) -> None: + strat_dir = strategy_templates_dir(repo_root) + if not os.path.isdir(strat_dir): + return + existing = app.jinja_loader + loaders = [FileSystemLoader(strat_dir)] + if existing is not None: + if isinstance(existing, ChoiceLoader): + loaders = list(existing.loaders) + loaders + else: + loaders.insert(0, existing) + app.jinja_loader = ChoiceLoader(loaders) + + +def register_strategy_trading(app: Flask, cfg: dict[str, Any]) -> None: + """cfg 由各市面 app 注入回调(仅 API / DB 差异)。""" + + login_required = cfg["login_required"] + + def _lr(f): + return login_required(f) + + @_lr + @app.route("/strategy/roll/preview", methods=["POST"]) + def strategy_roll_preview(): + data = request.get_json(silent=True) or request.form + err = _roll_preview_response(cfg, data, json_mode=request.is_json) + if request.is_json: + return jsonify(err) + if err.get("ok"): + p = err["preview"] + flash( + f"预览:约 {p.get('add_amount_display', '-')} 张," + f"合并均价 {p.get('avg_entry_after', '-')}," + f"打到止损约 {p.get('loss_at_sl_usdt', '-')}U" + ) + else: + flash(err.get("msg") or "预览失败") + return redirect(url_for("strategy_trading_page")) + + @_lr + @app.route("/strategy/roll/execute", methods=["POST"]) + def strategy_roll_execute(): + data = request.form + try: + ok, msg = _roll_execute(cfg, data) + except Exception as e: + fe = cfg.get("friendly_error") + msg = fe(e) if callable(fe) else str(e) + ok = False + flash(msg) + return redirect(url_for("strategy_trading_page")) + + @_lr + @app.route("/strategy/roll/cancel/", methods=["POST"]) + def strategy_roll_cancel_leg(leg_id: int): + conn = cfg["get_db"]() + try: + init_strategy_tables(conn) + ok, msg = cancel_roll_pending_leg(cfg, conn, leg_id) + finally: + conn.close() + if request.is_json: + return jsonify({"ok": ok, "msg": msg}) + flash(msg) + return redirect(url_for("strategy_trading_page")) + + @_lr + @app.route("/strategy/roll/docs") + def strategy_roll_docs(): + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "顺势加仓滚仓说明.md") + if not os.path.isfile(path): + flash("滚仓说明文档不存在") + return redirect(url_for("strategy_trading_page")) + with open(path, encoding="utf-8") as f: + raw = f.read() + return render_template( + "strategy_roll_docs.html", + doc_html=_roll_doc_markdown_to_html(raw), + exchange_display=cfg.get("exchange_display") or "", + ) + + +def _roll_doc_markdown_to_html(text: str) -> str: + """轻量 Markdown → HTML(仅供滚仓说明页)。""" + lines = text.splitlines() + out: list[str] = [] + i = 0 + in_code = False + code_buf: list[str] = [] + + def flush_code() -> None: + nonlocal code_buf + if code_buf: + out.append( + "
"
+                + html_module.escape("\n".join(code_buf))
+                + "
" + ) + code_buf = [] + + def inline_md(s: str) -> str: + s = html_module.escape(s) + s = re.sub(r"`([^`]+)`", r"\1", s) + s = re.sub(r"\*\*([^*]+)\*\*", r"\1", s) + return s + + while i < len(lines): + line = lines[i] + if line.strip().startswith("```"): + if in_code: + in_code = False + flush_code() + else: + in_code = True + i += 1 + continue + if in_code: + code_buf.append(line) + i += 1 + continue + if line.startswith("# "): + out.append(f"

{inline_md(line[2:].strip())}

") + elif line.startswith("## "): + out.append(f"

{inline_md(line[3:].strip())}

") + elif line.startswith("### "): + out.append(f"

{inline_md(line[4:].strip())}

") + elif line.strip() == "---": + out.append("
") + elif line.startswith("|") and "|" in line[1:]: + rows: list[str] = [] + while i < len(lines) and lines[i].startswith("|"): + rows.append(lines[i]) + i += 1 + if len(rows) >= 2 and re.match(r"^\|[\s\-:|]+\|$", rows[1].strip()): + out.append("") + hdr = [c.strip() for c in rows[0].strip("|").split("|")] + out.append("" + "".join(f"" for c in hdr) + "") + for row in rows[2:]: + cells = [c.strip() for c in row.strip("|").split("|")] + out.append("" + "".join(f"" for c in cells) + "") + out.append("
{inline_md(c)}
{inline_md(c)}
") + continue + elif re.match(r"^[-*]\s+", line): + out.append("
    ") + while i < len(lines) and re.match(r"^[-*]\s+", lines[i]): + item = re.sub(r"^[-*]\s+", "", lines[i]) + out.append(f"
  • {inline_md(item)}
  • ") + i += 1 + out.append("
") + continue + elif line.strip(): + out.append(f"

{inline_md(line.strip())}

") + i += 1 + flush_code() + return "\n".join(out) + + +def _row_to_dict(row) -> dict: + if row is None: + return {} + try: + return dict(row) + except Exception: + return {} + + +def _count_active_trends(conn, cfg: dict) -> int: + fn = cfg.get("count_active_trend_plans") + if callable(fn): + return int(fn(conn) or 0) + try: + return int( + conn.execute( + "SELECT COUNT(*) FROM trend_pullback_plans WHERE status='active'" + ).fetchone()[0] + ) + except Exception: + return 0 + + +def _risk_from_monitor(mon: dict, cfg: dict) -> tuple[Optional[float], Optional[str]]: + try: + rp = float(mon.get("risk_percent") or cfg.get("default_risk_percent", 2)) + except (TypeError, ValueError): + return None, "监控单风险%无效" + if rp <= 0: + return None, "监控单风险%须大于0" + return rp, None + + +def _contract_size(cfg: dict, ex_sym: str) -> float: + get_cs = cfg.get("get_contract_size") + if callable(get_cs): + try: + return float(get_cs(ex_sym) or 1.0) + except Exception: + pass + return 1.0 + + +def _roll_context(cfg: dict, data: dict) -> tuple[Optional[dict], Optional[str]]: + m = cfg.get("app_module") + if m is not None: + try: + from lib.trade.position_sizing_lib import OPEN_SOURCE_ROLL, assert_open_source_allowed + + mode = getattr(m, "POSITION_SIZING_MODE", None) or "risk" + ok_src, src_msg = assert_open_source_allowed(mode, OPEN_SOURCE_ROLL) + if not ok_src: + return None, src_msg + except Exception: + pass + get_db = cfg["get_db"] + symbol = cfg["normalize_symbol_input"](data.get("symbol") or "") + if not symbol: + return None, "请选择或填写币种" + direction = (data.get("direction") or "long").strip().lower() + ex_sym = cfg["normalize_exchange_symbol"](symbol) + conn = get_db() + init_strategy_tables(conn) + if _count_active_trends(conn, cfg) > 0: + conn.close() + return None, "存在运行中的趋势回调计划,请先结束后再滚仓" + mon = _get_active_monitor(conn, cfg, symbol, direction) + if not mon: + conn.close() + return None, "未找到该币种同向的下单监控持仓,请先在「实盘下单」开仓" + rg, legs_done, pending, roll_is_new = _get_or_create_roll_group_meta(conn, mon) + if pending > 0: + conn.close() + return None, "已有监控中的滚仓腿,请等待成交/失效或先删除后再提交" + conn_cap = get_db() + try: + capital = float(cfg["get_trading_capital_usdt"](conn_cap)) + finally: + conn_cap.close() + risk_pct, risk_err = _risk_from_monitor(mon, cfg) + if risk_err: + conn.close() + return None, risk_err + pos = cfg["get_position"](ex_sym, direction) + qty = float(pos.get("contracts") or 0) + if qty <= 0: + conn.close() + return None, "交易所无该方向持仓,无法滚仓" + entry = float(pos.get("entry_price") or mon.get("trigger_price") or 0) + if entry <= 0: + conn.close() + return None, "无法获取持仓均价" + mark_fn = cfg.get("get_mark_price") or cfg.get("get_price") + mark = mark_fn(symbol) if callable(mark_fn) else cfg["get_price"](symbol) + ctx = { + "conn": conn, + "mon": mon, + "rg": rg, + "legs_done": legs_done, + "symbol": symbol, + "direction": direction, + "ex_sym": ex_sym, + "qty": qty, + "entry": entry, + "mark": float(mark) if mark else None, + "capital": capital, + "risk_pct": float(risk_pct), + "tp0": float(mon.get("take_profit") or rg.get("initial_take_profit") or 0), + "contract_size": _contract_size(cfg, ex_sym), + } + return ctx, None + + +def _parse_roll_form(data: dict, ctx: dict) -> tuple[Optional[dict], Optional[str]]: + add_mode = (data.get("add_mode") or MARKET_MODE).strip().lower() + raw_sl = data.get("new_stop_loss") or data.get("sl") + if raw_sl in (None, ""): + return None, "请填写新止损价" + try: + new_sl = float(raw_sl) + except (TypeError, ValueError): + return None, "止损价格式错误" + if new_sl <= 0: + return None, "止损价须大于0" + fib_u = fib_l = bp = None + try: + if data.get("fib_upper") not in (None, ""): + fib_u = float(data.get("fib_upper")) + if data.get("fib_lower") not in (None, ""): + fib_l = float(data.get("fib_lower")) + if data.get("breakthrough_price") not in (None, ""): + bp = float(data.get("breakthrough_price")) + except (TypeError, ValueError): + return None, "价格参数格式错误" + + add_price = ctx.get("mark") + if add_mode == MARKET_MODE: + if add_price is None or add_price <= 0: + return None, "无法获取市价快照" + elif add_mode in FIB_MODES: + if fib_u is None or fib_l is None: + return None, "斐波须填写上沿 H 与下沿 L" + elif add_mode == BREAKOUT_MODE: + if bp is None: + return None, "突破加仓须填写突破价" + add_price = ctx.get("mark") + else: + return None, "加仓方式无效" + + return { + "add_mode": add_mode, + "new_stop_loss": new_sl, + "fib_upper": fib_u, + "fib_lower": fib_l, + "breakthrough_price": bp, + "add_price": add_price, + }, None + + +def _roll_preview_response(cfg: dict, data: dict, json_mode: bool = False) -> dict: + ctx, err = _roll_context(cfg, data) + if err: + return {"ok": False, "msg": err} + parsed, perr = _parse_roll_form(data, ctx) + if perr: + ctx["conn"].close() + return {"ok": False, "msg": perr} + conn = ctx["conn"] + try: + preview, perr2 = preview_roll( + direction=ctx["direction"], + symbol=ctx["symbol"], + qty_existing=ctx["qty"], + entry_existing=ctx["entry"], + initial_take_profit=ctx["tp0"], + add_mode=parsed["add_mode"], + new_stop_loss=parsed["new_stop_loss"], + risk_percent=ctx["risk_pct"], + capital_base_usdt=ctx["capital"], + add_price=parsed["add_price"], + fib_upper=parsed["fib_upper"], + fib_lower=parsed["fib_lower"], + breakthrough_price=parsed["breakthrough_price"], + legs_done=ctx["legs_done"], + contract_size=ctx["contract_size"], + ) + finally: + conn.close() + if perr2: + return {"ok": False, "msg": perr2} + amt_raw = float(preview["add_amount_raw"]) + amt_p = cfg["amount_to_precision"](ctx["ex_sym"], amt_raw) + preview["add_amount_display"] = amt_p if amt_p is not None else amt_raw + preview["risk_display"] = f"{ctx['risk_pct']:g}%≈{ctx['capital'] * ctx['risk_pct'] / 100:.2f}U" + price_fmt = cfg.get("price_fmt") + if callable(price_fmt): + preview["add_price_display"] = price_fmt(ctx["symbol"], preview["add_price"]) + preview["new_sl_display"] = price_fmt(ctx["symbol"], preview["new_stop_loss"]) + preview["tp_display"] = price_fmt(ctx["symbol"], preview["initial_take_profit"]) + return {"ok": True, "preview": preview} + + +def _roll_execute(cfg: dict, data: dict) -> tuple[bool, str]: + get_db = cfg["get_db"] + conn = None + try: + ok_live, reason = cfg["ensure_live_ready"]() + if not ok_live: + return False, reason or "实盘未就绪" + prev = _roll_preview_response(cfg, data) + if not prev.get("ok"): + return False, prev.get("msg") or "预览失败" + preview = prev["preview"] + symbol = cfg["normalize_symbol_input"](data.get("symbol") or "") + direction = preview["direction"] + ex_sym = cfg["normalize_exchange_symbol"](symbol) + add_mode = preview["add_mode"] + new_sl = float(preview["new_stop_loss"]) + tp0 = float(preview["initial_take_profit"]) + lev_fn = cfg.get("default_leverage") + if not callable(lev_fn): + lev_fn = lambda _s: 5 + leverage = int(data.get("leverage") or 0) or int(lev_fn(symbol)) + conn = get_db() + init_strategy_tables(conn) + mon = _get_active_monitor(conn, cfg, symbol, direction) + if not mon: + return False, "监控单已不存在" + rg, legs_done, pending, roll_is_new = _get_or_create_roll_group_meta(conn, mon) + if pending > 0: + return False, "已有监控中的滚仓腿,请先删除或等待结束" + if add_mode == MARKET_MODE: + amount = cfg["amount_to_precision"](ex_sym, float(preview["add_amount_raw"])) + if amount is None or amount <= 0: + return False, "加仓张数低于交易所最小精度" + order = cfg["market_add"](ex_sym, direction, amount, leverage) + fill = float( + cfg.get("resolve_fill_price", lambda o, s, p: p)( + order, ex_sym, preview["add_price"] + ) + or preview["add_price"] + ) + oid = str(order.get("id") or "") if isinstance(order, dict) else "" + cfg["replace_tpsl"](ex_sym, direction, new_sl, tp0, mon) + conn.execute( + """INSERT INTO roll_legs ( + roll_group_id, leg_index, add_mode, fib_upper, fib_lower, limit_price, + breakthrough_price, fill_price, amount, new_stop_loss, exchange_order_id, + status, created_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + rg["id"], + legs_done + 1, + preview["add_mode_label"], + preview.get("fib_upper"), + preview.get("fib_lower"), + None, + preview.get("breakthrough_price"), + fill, + amount, + new_sl, + oid, + "filled", + cfg["app_now_str"](), + ), + ) + conn.execute( + "UPDATE roll_groups SET leg_count=?, current_stop_loss=?, updated_at=? WHERE id=?", + (legs_done + 1, new_sl, cfg["app_now_str"](), rg["id"]), + ) + conn.execute( + "UPDATE order_monitors SET stop_loss=? WHERE id=?", + (new_sl, mon["id"]), + ) + conn.commit() + _maybe_notify_roll_started(cfg, rg, mon, symbol, direction, tp0, new_sl, roll_is_new=roll_is_new) + return True, f"市价加仓第 {legs_done + 1} 腿已成交,止损已更新,止盈仍为首仓" + # 程序监控:斐波 / 突破 + limit_px = None + if add_mode in FIB_MODES: + px_fn = cfg.get("price_to_precision") + limit_px = float(preview["add_price"]) + if callable(px_fn): + limit_px = float(px_fn(ex_sym, limit_px) or limit_px) + mark_fn = cfg.get("get_mark_price") or cfg.get("get_price") + last_mark = mark_fn(symbol) if callable(mark_fn) else preview["add_price"] + conn.execute( + """INSERT INTO roll_legs ( + roll_group_id, leg_index, add_mode, fib_upper, fib_lower, limit_price, + breakthrough_price, new_stop_loss, last_mark_price, status, created_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?)""", + ( + rg["id"], + legs_done + 1, + preview["add_mode_label"], + preview.get("fib_upper"), + preview.get("fib_lower"), + limit_px, + preview.get("breakthrough_price"), + new_sl, + last_mark, + "pending", + cfg["app_now_str"](), + ), + ) + conn.commit() + _maybe_notify_roll_started(cfg, rg, mon, symbol, direction, tp0, new_sl, roll_is_new=roll_is_new) + return True, f"已提交{preview['add_mode_label']}监控,触价后将市价加仓并更新止损" + except Exception as e: + fe = cfg.get("friendly_error") + return False, fe(e) if callable(fe) else str(e) + finally: + if conn is not None: + try: + conn.close() + except Exception: + pass + + +def _maybe_notify_roll_started(cfg, rg, mon, symbol, direction, tp0, new_sl, *, roll_is_new: bool) -> None: + if not roll_is_new: + return + try: + from lib.strategy.strategy_wechat_notify import notify_roll_group_started + + notify_roll_group_started( + cfg, + group_id=int(rg["id"]), + symbol=symbol, + direction=direction, + order_monitor_id=int(mon["id"]), + initial_take_profit=tp0, + initial_stop_loss=float(mon.get("stop_loss") or new_sl), + ) + except Exception: + pass + + +def _get_active_monitor(conn, cfg: dict, symbol: str, direction: str) -> Optional[dict]: + row = conn.execute( + "SELECT * FROM order_monitors WHERE status='active' AND symbol=? AND direction=? ORDER BY id DESC LIMIT 1", + (symbol, direction), + ).fetchone() + return _row_to_dict(row) if row else None + + +def _get_or_create_roll_group_meta(conn, mon: dict) -> tuple[dict, int, int, bool]: + """返回 (roll_group, filled_legs, pending_legs, is_new_group)。""" + row = conn.execute( + "SELECT * FROM roll_groups WHERE order_monitor_id=? AND status='active' ORDER BY id DESC LIMIT 1", + (mon["id"],), + ).fetchone() + if row: + d = _row_to_dict(row) + gid = int(d["id"]) + filled = count_filled_roll_legs(conn, gid) + pending = count_pending_roll_legs(conn, gid) + return d, filled, pending, False + now = mon.get("created_at") or "" + cur = conn.execute( + """INSERT INTO roll_groups ( + order_monitor_id, symbol, exchange_symbol, direction, + initial_take_profit, initial_stop_loss, current_stop_loss, + risk_percent, leg_count, status, created_at, updated_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + mon["id"], + mon["symbol"], + mon.get("exchange_symbol"), + mon["direction"], + mon.get("take_profit"), + mon.get("stop_loss"), + mon.get("stop_loss"), + mon.get("risk_percent") or 2, + 0, + "active", + now, + now, + ), + ) + gid = int(cur.lastrowid) + return ( + { + "id": gid, + "leg_count": 0, + "initial_take_profit": mon.get("take_profit"), + "initial_stop_loss": mon.get("stop_loss"), + "symbol": mon.get("symbol"), + "direction": mon.get("direction"), + }, + 0, + 0, + True, + ) + + +def roll_sync_after_external_close(cfg: dict, conn, symbol: str, direction: str) -> dict: + """供 hub / del_order 调用的滚仓同步入口。""" + return sync_roll_after_external_close( + cfg, conn, symbol, direction, reason="手动平仓,滚仓监控已结束" + ) + diff --git a/strategy_roll_lib.py b/lib/strategy/strategy_roll_lib.py similarity index 96% rename from strategy_roll_lib.py rename to lib/strategy/strategy_roll_lib.py index c7ab12a..2a1186b 100644 --- a/strategy_roll_lib.py +++ b/lib/strategy/strategy_roll_lib.py @@ -1,385 +1,385 @@ -"""顺势加仓(滚仓):纯计算。人工触发;止盈锁定首仓;程序监控触价市价成交。""" -from __future__ import annotations - -from typing import Any, Optional, Tuple - -from fib_key_monitor_lib import calc_fib_plan, fib_invalidate_by_mark - -ROLL_MAX_LEGS_LONG = 3 -ROLL_MAX_LEGS_SHORT = 3 - -MARKET_MODE = "market" -FIB_MODES = frozenset({"fib_618", "fib_786"}) -BREAKOUT_MODE = "breakout" - -MODE_LABELS = { - MARKET_MODE: "市价加仓", - "fib_618": "斐波0.618", - "fib_786": "斐波0.786", - BREAKOUT_MODE: "突破加仓", -} - - -def fib_ratio_from_mode(mode: str) -> Optional[float]: - m = (mode or "").strip().lower() - if m in ("fib_618", "618", "0.618"): - return 0.618 - if m in ("fib_786", "786", "0.786"): - return 0.786 - return None - - -def mode_label(mode: str) -> str: - m = (mode or MARKET_MODE).strip().lower() - return MODE_LABELS.get(m, m) - - -def fib_limit_entry(direction: str, upper: float, lower: float, mode: str) -> Tuple[Optional[float], Optional[str]]: - """H/L 仅用于计算限价加仓价;多:下沿=止损侧;空:上沿=止损侧。""" - ratio = fib_ratio_from_mode(mode) - if ratio is None: - return None, "斐波档位无效" - h, l = float(upper), float(lower) - if h <= l: - return None, "上沿须大于下沿" - direction = (direction or "long").strip().lower() - if direction == "short": - plan = calc_fib_plan("short", h, l, ratio) - else: - plan = calc_fib_plan("long", h, l, ratio) - if not plan: - return None, "无法计算斐波限价" - entry, _sl, _tp = plan - return float(entry), None - - -def max_roll_legs(direction: str) -> int: - return ROLL_MAX_LEGS_LONG if (direction or "long").strip().lower() == "long" else ROLL_MAX_LEGS_SHORT - - -def avg_entry_after_add( - qty_existing: float, - entry_existing: float, - add_qty: float, - add_price: float, -) -> float: - q1 = float(qty_existing) - e1 = float(entry_existing) - q2 = float(add_qty) - e2 = float(add_price) - total = q1 + q2 - if total <= 0: - return 0.0 - return (q1 * e1 + q2 * e2) / total - - -def calc_risk_budget_usdt(capital_base_usdt: float, risk_percent: float) -> float: - return float(capital_base_usdt) * (float(risk_percent) / 100.0) - - -def solve_add_amount_for_total_risk( - direction: str, - qty_existing: float, - entry_existing: float, - add_price: float, - new_stop: float, - risk_budget_usdt: float, - contract_size: float = 1.0, -) -> Tuple[Optional[float], Optional[str]]: - """ - 合并持仓打到 new_stop 时总亏损 ≈ risk_budget(方案 C)。 - long: (avg - SL) * (Q1+Q2) * cs = B => Q2 = (B/cs - Q1*(E1-SL)) / (E2-SL) - short: (SL - avg) * (Q1+Q2) * cs = B => Q2 = (B/cs - Q1*(SL-E1)) / (SL-E2) - """ - try: - q1 = float(qty_existing) - e1 = float(entry_existing) - e2 = float(add_price) - sl = float(new_stop) - b = float(risk_budget_usdt) - cs = float(contract_size) if contract_size else 1.0 - except (TypeError, ValueError): - return None, "参数格式错误" - if q1 <= 0 or e1 <= 0 or e2 <= 0 or b <= 0 or cs <= 0: - return None, "持仓或风险预算无效" - direction = (direction or "long").strip().lower() - if direction == "short": - denom = sl - e2 - numer = b / cs - q1 * (sl - e1) - if denom <= 0: - return None, "做空:新止损须高于加仓价" - else: - denom = e2 - sl - numer = b / cs - q1 * (e1 - sl) - if denom <= 0: - return None, "做多:新止损须低于加仓价" - q2 = numer / denom - if q2 <= 0: - return None, "按当前新止损与风险预算,无需加仓或无法再加(已满足风险上限)" - return q2, None - - -def loss_at_stop_usdt( - direction: str, - avg: float, - qty: float, - stop: float, - contract_size: float = 1.0, -) -> float: - cs = float(contract_size or 1.0) - direction = (direction or "long").strip().lower() - if direction == "short": - return (float(stop) - float(avg)) * float(qty) * cs - return (float(avg) - float(stop)) * float(qty) * cs - - -def reward_at_tp_usdt( - direction: str, - avg: float, - take_profit: float, - qty: float, - contract_size: float = 1.0, -) -> float: - cs = float(contract_size or 1.0) - direction = (direction or "long").strip().lower() - if direction == "short": - return (float(avg) - float(take_profit)) * float(qty) * cs - return (float(take_profit) - float(avg)) * float(qty) * cs - - -def roll_fib_trigger_crossed( - direction: str, - prev_mark: Optional[float], - mark: float, - limit_price: float, -) -> bool: - """斐波:多=向下穿越限价;空=向上穿越限价。""" - try: - m = float(mark) - lv = float(limit_price) - pm = float(prev_mark) if prev_mark is not None else None - except (TypeError, ValueError): - return False - direction = (direction or "long").strip().lower() - if direction == "long": - if pm is None: - return m <= lv - return pm > lv and m <= lv - if pm is None: - return m >= lv - return pm < lv and m >= lv - - -def roll_breakout_trigger_crossed( - direction: str, - prev_mark: Optional[float], - mark: float, - breakthrough_price: float, -) -> bool: - """突破:多=向上穿越突破价;空=向下穿越突破价。""" - try: - m = float(mark) - bp = float(breakthrough_price) - pm = float(prev_mark) if prev_mark is not None else None - except (TypeError, ValueError): - return False - direction = (direction or "long").strip().lower() - if direction == "long": - if pm is None: - return m > bp - return pm <= bp and m > bp - if pm is None: - return m < bp - return pm >= bp and m < bp - - -def roll_fib_invalidate(direction: str, mark: float, upper: float, lower: float) -> bool: - """斐波 pending 失效:止盈侧突破(多 mark>=H;空 mark<=L)。""" - return fib_invalidate_by_mark(direction, mark, upper, lower) - - -def roll_breakout_invalidate(direction: str, mark: float, stop_loss: float) -> bool: - """突破 pending 失效:未到突破价先触达止损侧(多 mark<=S;空 mark>=S)。""" - try: - m = float(mark) - sl = float(stop_loss) - except (TypeError, ValueError): - return False - direction = (direction or "long").strip().lower() - if direction == "long": - return m <= sl - return m >= sl - - -def validate_roll_geometry( - direction: str, - add_mode: str, - *, - new_stop_loss: float, - add_price: Optional[float] = None, - fib_upper: Optional[float] = None, - fib_lower: Optional[float] = None, - breakthrough_price: Optional[float] = None, - entry_existing: float = 0.0, - initial_take_profit: float = 0.0, - mark_price: Optional[float] = None, -) -> Optional[str]: - direction = (direction or "long").strip().lower() - mode = (add_mode or MARKET_MODE).strip().lower() - try: - sl = float(new_stop_loss) - tp = float(initial_take_profit) - e1 = float(entry_existing or 0) - except (TypeError, ValueError): - return "止损/止盈格式错误" - if sl <= 0 or tp <= 0: - return "止损与首仓止盈须大于0" - if direction == "long": - if e1 > 0 and tp <= e1: - return "做多:首仓止盈须高于当前持仓均价" - else: - if e1 > 0 and tp >= e1: - return "做空:首仓止盈须低于当前持仓均价" - - if mode == MARKET_MODE: - if add_price is None or float(add_price) <= 0: - return "市价加仓需要有效参考价" - entry_add = float(add_price) - elif mode in FIB_MODES: - if fib_upper is None or fib_lower is None: - return "斐波须填写上沿 H 与下沿 L" - entry_add, err = fib_limit_entry(direction, float(fib_upper), float(fib_lower), mode) - if err: - return err - if entry_add is None or entry_add <= 0: - return "无法计算斐波限价" - elif mode == BREAKOUT_MODE: - if breakthrough_price is None: - return "突破加仓须填写突破价" - try: - bp = float(breakthrough_price) - except (TypeError, ValueError): - return "突破价格式错误" - if bp <= 0: - return "突破价须大于0" - entry_add = bp - if direction == "long": - if sl >= bp: - return "做多:止损须低于突破价" - if mark_price is not None and float(mark_price) >= bp: - return "做多:当前价须低于突破价(等待向上突破)" - else: - if sl <= bp: - return "做空:止损须高于突破价" - if mark_price is not None and float(mark_price) <= bp: - return "做空:当前价须高于突破价(等待向下跌破)" - else: - return "加仓方式无效" - - if mode != BREAKOUT_MODE: - entry_add = float(entry_add) # type: ignore[arg-type] - if direction == "long": - if sl >= entry_add: - return "做多:新止损须低于加仓价" - else: - if sl <= entry_add: - return "做空:新止损须高于加仓价" - return None - - -def preview_roll( - *, - direction: str, - symbol: str, - qty_existing: float, - entry_existing: float, - initial_take_profit: float, - add_mode: str, - new_stop_loss: Optional[float] = None, - risk_percent: float, - capital_base_usdt: float, - add_price: Optional[float] = None, - fib_upper: Optional[float] = None, - fib_lower: Optional[float] = None, - breakthrough_price: Optional[float] = None, - legs_done: int = 0, - contract_size: float = 1.0, -) -> Tuple[Optional[dict[str, Any]], Optional[str]]: - direction = (direction or "long").strip().lower() - if legs_done >= max_roll_legs(direction): - return None, f"{'做多' if direction == 'long' else '做空'}滚仓已达 {max_roll_legs(direction)} 次上限" - mode = (add_mode or MARKET_MODE).strip().lower() - if new_stop_loss is None: - return None, "请填写新止损价" - try: - sl = float(new_stop_loss) - except (TypeError, ValueError): - return None, "止损价格式错误" - if sl <= 0: - return None, "止损须大于0" - - geom_err = validate_roll_geometry( - direction, - mode, - new_stop_loss=sl, - add_price=add_price, - fib_upper=fib_upper, - fib_lower=fib_lower, - breakthrough_price=breakthrough_price, - entry_existing=entry_existing, - initial_take_profit=initial_take_profit, - mark_price=add_price if mode == BREAKOUT_MODE else add_price, - ) - if geom_err: - return None, geom_err - - if mode == MARKET_MODE: - entry_add = float(add_price) # validated - elif mode in FIB_MODES: - entry_add, _ = fib_limit_entry(direction, float(fib_upper), float(fib_lower), mode) - entry_add = float(entry_add or 0) - else: - entry_add = float(breakthrough_price or 0) - - risk_budget = calc_risk_budget_usdt(capital_base_usdt, risk_percent) - q2_raw, err = solve_add_amount_for_total_risk( - direction, - qty_existing, - entry_existing, - entry_add, - sl, - risk_budget, - contract_size, - ) - if err: - return None, err - q2 = float(q2_raw) - new_qty = qty_existing + q2 - new_avg = avg_entry_after_add(qty_existing, entry_existing, q2, entry_add) - cs = float(contract_size or 1.0) - loss_sl = loss_at_stop_usdt(direction, new_avg, new_qty, sl, cs) - reward_tp = reward_at_tp_usdt(direction, new_avg, initial_take_profit, new_qty, cs) - return { - "symbol": symbol, - "direction": direction, - "add_mode": mode, - "add_mode_label": mode_label(mode), - "add_price": round(entry_add, 10), - "new_stop_loss": round(sl, 10), - "breakthrough_price": float(breakthrough_price) if breakthrough_price not in (None, "") else None, - "initial_take_profit": float(initial_take_profit), - "risk_percent": float(risk_percent), - "risk_budget_usdt": round(risk_budget, 4), - "add_amount_raw": q2, - "qty_existing": float(qty_existing), - "entry_existing": float(entry_existing), - "qty_after": new_qty, - "avg_entry_after": round(new_avg, 10), - "loss_at_sl_usdt": round(loss_sl, 4), - "reward_at_tp_usdt": round(reward_tp, 4), - "legs_done": int(legs_done), - "leg_index_next": int(legs_done) + 1, - "fib_upper": fib_upper, - "fib_lower": fib_lower, - "contract_size": cs, - }, None +"""顺势加仓(滚仓):纯计算。人工触发;止盈锁定首仓;程序监控触价市价成交。""" +from __future__ import annotations + +from typing import Any, Optional, Tuple + +from lib.key_monitor.fib_key_monitor_lib import calc_fib_plan, fib_invalidate_by_mark + +ROLL_MAX_LEGS_LONG = 3 +ROLL_MAX_LEGS_SHORT = 3 + +MARKET_MODE = "market" +FIB_MODES = frozenset({"fib_618", "fib_786"}) +BREAKOUT_MODE = "breakout" + +MODE_LABELS = { + MARKET_MODE: "市价加仓", + "fib_618": "斐波0.618", + "fib_786": "斐波0.786", + BREAKOUT_MODE: "突破加仓", +} + + +def fib_ratio_from_mode(mode: str) -> Optional[float]: + m = (mode or "").strip().lower() + if m in ("fib_618", "618", "0.618"): + return 0.618 + if m in ("fib_786", "786", "0.786"): + return 0.786 + return None + + +def mode_label(mode: str) -> str: + m = (mode or MARKET_MODE).strip().lower() + return MODE_LABELS.get(m, m) + + +def fib_limit_entry(direction: str, upper: float, lower: float, mode: str) -> Tuple[Optional[float], Optional[str]]: + """H/L 仅用于计算限价加仓价;多:下沿=止损侧;空:上沿=止损侧。""" + ratio = fib_ratio_from_mode(mode) + if ratio is None: + return None, "斐波档位无效" + h, l = float(upper), float(lower) + if h <= l: + return None, "上沿须大于下沿" + direction = (direction or "long").strip().lower() + if direction == "short": + plan = calc_fib_plan("short", h, l, ratio) + else: + plan = calc_fib_plan("long", h, l, ratio) + if not plan: + return None, "无法计算斐波限价" + entry, _sl, _tp = plan + return float(entry), None + + +def max_roll_legs(direction: str) -> int: + return ROLL_MAX_LEGS_LONG if (direction or "long").strip().lower() == "long" else ROLL_MAX_LEGS_SHORT + + +def avg_entry_after_add( + qty_existing: float, + entry_existing: float, + add_qty: float, + add_price: float, +) -> float: + q1 = float(qty_existing) + e1 = float(entry_existing) + q2 = float(add_qty) + e2 = float(add_price) + total = q1 + q2 + if total <= 0: + return 0.0 + return (q1 * e1 + q2 * e2) / total + + +def calc_risk_budget_usdt(capital_base_usdt: float, risk_percent: float) -> float: + return float(capital_base_usdt) * (float(risk_percent) / 100.0) + + +def solve_add_amount_for_total_risk( + direction: str, + qty_existing: float, + entry_existing: float, + add_price: float, + new_stop: float, + risk_budget_usdt: float, + contract_size: float = 1.0, +) -> Tuple[Optional[float], Optional[str]]: + """ + 合并持仓打到 new_stop 时总亏损 ≈ risk_budget(方案 C)。 + long: (avg - SL) * (Q1+Q2) * cs = B => Q2 = (B/cs - Q1*(E1-SL)) / (E2-SL) + short: (SL - avg) * (Q1+Q2) * cs = B => Q2 = (B/cs - Q1*(SL-E1)) / (SL-E2) + """ + try: + q1 = float(qty_existing) + e1 = float(entry_existing) + e2 = float(add_price) + sl = float(new_stop) + b = float(risk_budget_usdt) + cs = float(contract_size) if contract_size else 1.0 + except (TypeError, ValueError): + return None, "参数格式错误" + if q1 <= 0 or e1 <= 0 or e2 <= 0 or b <= 0 or cs <= 0: + return None, "持仓或风险预算无效" + direction = (direction or "long").strip().lower() + if direction == "short": + denom = sl - e2 + numer = b / cs - q1 * (sl - e1) + if denom <= 0: + return None, "做空:新止损须高于加仓价" + else: + denom = e2 - sl + numer = b / cs - q1 * (e1 - sl) + if denom <= 0: + return None, "做多:新止损须低于加仓价" + q2 = numer / denom + if q2 <= 0: + return None, "按当前新止损与风险预算,无需加仓或无法再加(已满足风险上限)" + return q2, None + + +def loss_at_stop_usdt( + direction: str, + avg: float, + qty: float, + stop: float, + contract_size: float = 1.0, +) -> float: + cs = float(contract_size or 1.0) + direction = (direction or "long").strip().lower() + if direction == "short": + return (float(stop) - float(avg)) * float(qty) * cs + return (float(avg) - float(stop)) * float(qty) * cs + + +def reward_at_tp_usdt( + direction: str, + avg: float, + take_profit: float, + qty: float, + contract_size: float = 1.0, +) -> float: + cs = float(contract_size or 1.0) + direction = (direction or "long").strip().lower() + if direction == "short": + return (float(avg) - float(take_profit)) * float(qty) * cs + return (float(take_profit) - float(avg)) * float(qty) * cs + + +def roll_fib_trigger_crossed( + direction: str, + prev_mark: Optional[float], + mark: float, + limit_price: float, +) -> bool: + """斐波:多=向下穿越限价;空=向上穿越限价。""" + try: + m = float(mark) + lv = float(limit_price) + pm = float(prev_mark) if prev_mark is not None else None + except (TypeError, ValueError): + return False + direction = (direction or "long").strip().lower() + if direction == "long": + if pm is None: + return m <= lv + return pm > lv and m <= lv + if pm is None: + return m >= lv + return pm < lv and m >= lv + + +def roll_breakout_trigger_crossed( + direction: str, + prev_mark: Optional[float], + mark: float, + breakthrough_price: float, +) -> bool: + """突破:多=向上穿越突破价;空=向下穿越突破价。""" + try: + m = float(mark) + bp = float(breakthrough_price) + pm = float(prev_mark) if prev_mark is not None else None + except (TypeError, ValueError): + return False + direction = (direction or "long").strip().lower() + if direction == "long": + if pm is None: + return m > bp + return pm <= bp and m > bp + if pm is None: + return m < bp + return pm >= bp and m < bp + + +def roll_fib_invalidate(direction: str, mark: float, upper: float, lower: float) -> bool: + """斐波 pending 失效:止盈侧突破(多 mark>=H;空 mark<=L)。""" + return fib_invalidate_by_mark(direction, mark, upper, lower) + + +def roll_breakout_invalidate(direction: str, mark: float, stop_loss: float) -> bool: + """突破 pending 失效:未到突破价先触达止损侧(多 mark<=S;空 mark>=S)。""" + try: + m = float(mark) + sl = float(stop_loss) + except (TypeError, ValueError): + return False + direction = (direction or "long").strip().lower() + if direction == "long": + return m <= sl + return m >= sl + + +def validate_roll_geometry( + direction: str, + add_mode: str, + *, + new_stop_loss: float, + add_price: Optional[float] = None, + fib_upper: Optional[float] = None, + fib_lower: Optional[float] = None, + breakthrough_price: Optional[float] = None, + entry_existing: float = 0.0, + initial_take_profit: float = 0.0, + mark_price: Optional[float] = None, +) -> Optional[str]: + direction = (direction or "long").strip().lower() + mode = (add_mode or MARKET_MODE).strip().lower() + try: + sl = float(new_stop_loss) + tp = float(initial_take_profit) + e1 = float(entry_existing or 0) + except (TypeError, ValueError): + return "止损/止盈格式错误" + if sl <= 0 or tp <= 0: + return "止损与首仓止盈须大于0" + if direction == "long": + if e1 > 0 and tp <= e1: + return "做多:首仓止盈须高于当前持仓均价" + else: + if e1 > 0 and tp >= e1: + return "做空:首仓止盈须低于当前持仓均价" + + if mode == MARKET_MODE: + if add_price is None or float(add_price) <= 0: + return "市价加仓需要有效参考价" + entry_add = float(add_price) + elif mode in FIB_MODES: + if fib_upper is None or fib_lower is None: + return "斐波须填写上沿 H 与下沿 L" + entry_add, err = fib_limit_entry(direction, float(fib_upper), float(fib_lower), mode) + if err: + return err + if entry_add is None or entry_add <= 0: + return "无法计算斐波限价" + elif mode == BREAKOUT_MODE: + if breakthrough_price is None: + return "突破加仓须填写突破价" + try: + bp = float(breakthrough_price) + except (TypeError, ValueError): + return "突破价格式错误" + if bp <= 0: + return "突破价须大于0" + entry_add = bp + if direction == "long": + if sl >= bp: + return "做多:止损须低于突破价" + if mark_price is not None and float(mark_price) >= bp: + return "做多:当前价须低于突破价(等待向上突破)" + else: + if sl <= bp: + return "做空:止损须高于突破价" + if mark_price is not None and float(mark_price) <= bp: + return "做空:当前价须高于突破价(等待向下跌破)" + else: + return "加仓方式无效" + + if mode != BREAKOUT_MODE: + entry_add = float(entry_add) # type: ignore[arg-type] + if direction == "long": + if sl >= entry_add: + return "做多:新止损须低于加仓价" + else: + if sl <= entry_add: + return "做空:新止损须高于加仓价" + return None + + +def preview_roll( + *, + direction: str, + symbol: str, + qty_existing: float, + entry_existing: float, + initial_take_profit: float, + add_mode: str, + new_stop_loss: Optional[float] = None, + risk_percent: float, + capital_base_usdt: float, + add_price: Optional[float] = None, + fib_upper: Optional[float] = None, + fib_lower: Optional[float] = None, + breakthrough_price: Optional[float] = None, + legs_done: int = 0, + contract_size: float = 1.0, +) -> Tuple[Optional[dict[str, Any]], Optional[str]]: + direction = (direction or "long").strip().lower() + if legs_done >= max_roll_legs(direction): + return None, f"{'做多' if direction == 'long' else '做空'}滚仓已达 {max_roll_legs(direction)} 次上限" + mode = (add_mode or MARKET_MODE).strip().lower() + if new_stop_loss is None: + return None, "请填写新止损价" + try: + sl = float(new_stop_loss) + except (TypeError, ValueError): + return None, "止损价格式错误" + if sl <= 0: + return None, "止损须大于0" + + geom_err = validate_roll_geometry( + direction, + mode, + new_stop_loss=sl, + add_price=add_price, + fib_upper=fib_upper, + fib_lower=fib_lower, + breakthrough_price=breakthrough_price, + entry_existing=entry_existing, + initial_take_profit=initial_take_profit, + mark_price=add_price if mode == BREAKOUT_MODE else add_price, + ) + if geom_err: + return None, geom_err + + if mode == MARKET_MODE: + entry_add = float(add_price) # validated + elif mode in FIB_MODES: + entry_add, _ = fib_limit_entry(direction, float(fib_upper), float(fib_lower), mode) + entry_add = float(entry_add or 0) + else: + entry_add = float(breakthrough_price or 0) + + risk_budget = calc_risk_budget_usdt(capital_base_usdt, risk_percent) + q2_raw, err = solve_add_amount_for_total_risk( + direction, + qty_existing, + entry_existing, + entry_add, + sl, + risk_budget, + contract_size, + ) + if err: + return None, err + q2 = float(q2_raw) + new_qty = qty_existing + q2 + new_avg = avg_entry_after_add(qty_existing, entry_existing, q2, entry_add) + cs = float(contract_size or 1.0) + loss_sl = loss_at_stop_usdt(direction, new_avg, new_qty, sl, cs) + reward_tp = reward_at_tp_usdt(direction, new_avg, initial_take_profit, new_qty, cs) + return { + "symbol": symbol, + "direction": direction, + "add_mode": mode, + "add_mode_label": mode_label(mode), + "add_price": round(entry_add, 10), + "new_stop_loss": round(sl, 10), + "breakthrough_price": float(breakthrough_price) if breakthrough_price not in (None, "") else None, + "initial_take_profit": float(initial_take_profit), + "risk_percent": float(risk_percent), + "risk_budget_usdt": round(risk_budget, 4), + "add_amount_raw": q2, + "qty_existing": float(qty_existing), + "entry_existing": float(entry_existing), + "qty_after": new_qty, + "avg_entry_after": round(new_avg, 10), + "loss_at_sl_usdt": round(loss_sl, 4), + "reward_at_tp_usdt": round(reward_tp, 4), + "legs_done": int(legs_done), + "leg_index_next": int(legs_done) + 1, + "fib_upper": fib_upper, + "fib_lower": fib_lower, + "contract_size": cs, + }, None diff --git a/strategy_roll_monitor_lib.py b/lib/strategy/strategy_roll_monitor_lib.py similarity index 94% rename from strategy_roll_monitor_lib.py rename to lib/strategy/strategy_roll_monitor_lib.py index b502ddb..2990bef 100644 --- a/strategy_roll_monitor_lib.py +++ b/lib/strategy/strategy_roll_monitor_lib.py @@ -1,520 +1,520 @@ -"""滚仓程序监控:斐波/突破触价市价成交、失效、外部平仓同步(各所共用)。""" -from __future__ import annotations - -from typing import Any, Optional - -from strategy_roll_lib import ( - BREAKOUT_MODE, - FIB_MODES, - MARKET_MODE, - mode_label, - roll_breakout_invalidate, - roll_breakout_trigger_crossed, - roll_fib_invalidate, - roll_fib_trigger_crossed, - calc_risk_budget_usdt, - max_roll_legs, - preview_roll, - solve_add_amount_for_total_risk, -) -from strategy_db import init_strategy_tables - -ROLL_LEG_STATUS_LABELS = { - "pending": "监控中", - "filled": "已成交", - "cancelled": "已删除", - "invalidated": "已失效", -} - - -def roll_leg_status_label(status: Optional[str]) -> str: - s = (status or "").strip().lower() - return ROLL_LEG_STATUS_LABELS.get(s, status or "—") - - -def check_roll_monitors(cfg: dict[str, Any]) -> None: - get_db = cfg["get_db"] - conn = get_db() - try: - init_strategy_tables(conn) - _reconcile_roll_groups(conn, cfg) - _check_pending_roll_legs(conn, cfg) - conn.commit() - except Exception: - try: - conn.rollback() - except Exception: - pass - finally: - try: - conn.close() - except Exception: - pass - - -def sync_roll_after_external_close( - cfg: dict, conn, symbol: str, direction: str, *, reason: str = "持仓已平" -) -> dict[str, Any]: - """中控/实例手动平仓后:取消 pending 腿并关闭 active 滚仓组(保留 filled 历史)。""" - norm = cfg.get("normalize_symbol_input") - sym = norm(symbol) if callable(norm) else (symbol or "").strip() - if not sym: - return {"ok": False, "msg": "symbol 无效", "closed_groups": 0, "cancelled_legs": 0} - direction = (direction or "long").strip().lower() - init_strategy_tables(conn) - rows = conn.execute( - """SELECT g.* FROM roll_groups g - WHERE g.status='active' AND g.symbol=? AND g.direction=?""", - (sym, direction), - ).fetchall() - closed = cancelled = 0 - for row in rows: - g = _row_dict(row) - cancelled += _cancel_pending_legs_for_group(conn, cfg, g, status="cancelled") - cur = conn.execute( - "UPDATE roll_groups SET status='closed', updated_at=? WHERE id=? AND status='active'", - (_now(cfg), int(g["id"])), - ) - if getattr(cur, "rowcount", 0): - closed += 1 - try: - from strategy_wechat_notify import notify_roll_group_ended - - notify_roll_group_ended( - cfg, - group_id=int(g["id"]), - symbol=sym, - direction=direction, - reason=reason, - leg_count=int(g.get("leg_count") or 0), - ) - except Exception: - pass - try: - from strategy_snapshot_lib import save_roll_group_snapshot - - save_roll_group_snapshot(cfg, conn, g, result_label="结束") - except Exception: - pass - return { - "ok": True, - "symbol": sym, - "direction": direction, - "closed_groups": closed, - "cancelled_legs": cancelled, - } - - -def cancel_roll_pending_leg(cfg: dict, conn, leg_id: int) -> tuple[bool, str]: - """用户删除 pending 滚仓腿(不可修改,仅删除)。""" - init_strategy_tables(conn) - row = conn.execute( - "SELECT l.*, g.symbol, g.direction, g.status AS group_status FROM roll_legs l " - "INNER JOIN roll_groups g ON g.id = l.roll_group_id WHERE l.id=?", - (int(leg_id),), - ).fetchone() - if not row: - return False, "滚仓腿不存在" - leg = _row_dict(row) - if (leg.get("status") or "").strip().lower() != "pending": - return False, "仅监控中的腿可删除" - _cancel_roll_leg_order(cfg, {"symbol": leg.get("symbol"), "exchange_symbol": leg.get("exchange_symbol")}, leg) - conn.execute( - "UPDATE roll_legs SET status='cancelled' WHERE id=? AND status='pending'", - (int(leg_id),), - ) - conn.commit() - return True, "已删除滚仓监控" - - -def count_filled_roll_legs(conn, roll_group_id: int) -> int: - row = conn.execute( - "SELECT COUNT(*) FROM roll_legs WHERE roll_group_id=? AND status='filled'", - (int(roll_group_id),), - ).fetchone() - return int(row[0] if row else 0) - - -def count_pending_roll_legs(conn, roll_group_id: int) -> int: - row = conn.execute( - "SELECT COUNT(*) FROM roll_legs WHERE roll_group_id=? AND status='pending'", - (int(roll_group_id),), - ).fetchone() - return int(row[0] if row else 0) - - -def _row_dict(row) -> dict: - if row is None: - return {} - try: - return dict(row) - except Exception: - return {} - - -def _now(cfg: dict) -> str: - fn = cfg.get("app_now_str") - return fn() if callable(fn) else "" - - -def _cancel_pending_legs_for_group(conn, cfg: dict, group: dict, *, status: str = "cancelled") -> int: - gid = int(group["id"]) - n = 0 - for leg in conn.execute( - "SELECT * FROM roll_legs WHERE roll_group_id=? AND status='pending'", - (gid,), - ).fetchall(): - ld = _row_dict(leg) - _cancel_roll_leg_order(cfg, group, ld) - conn.execute( - "UPDATE roll_legs SET status=? WHERE id=? AND status='pending'", - (status, ld["id"]), - ) - n += 1 - return n - - -def _close_roll_group(conn, cfg: dict, group: dict, *, reason: str = "下单监控已结案或交易所无同向持仓") -> None: - gid = int(group["id"]) - _cancel_pending_legs_for_group(conn, cfg, group, status="cancelled") - cur = conn.execute( - "UPDATE roll_groups SET status='closed', updated_at=? WHERE id=? AND status='active'", - (_now(cfg), gid), - ) - if getattr(cur, "rowcount", 0): - try: - from strategy_wechat_notify import notify_roll_group_ended - - notify_roll_group_ended( - cfg, - group_id=gid, - symbol=group.get("symbol") or "", - direction=group.get("direction") or "long", - reason=reason, - leg_count=int(group.get("leg_count") or 0), - ) - except Exception: - pass - try: - from strategy_snapshot_lib import save_roll_group_snapshot - - save_roll_group_snapshot(cfg, conn, group, result_label="结束") - except Exception: - pass - - -def _reconcile_roll_groups(conn, cfg: dict) -> None: - rows = conn.execute( - """SELECT g.*, m.status AS monitor_status - FROM roll_groups g - LEFT JOIN order_monitors m ON m.id = g.order_monitor_id - WHERE g.status='active'""" - ).fetchall() - for row in rows: - g = _row_dict(row) - symbol = g.get("symbol") or "" - direction = (g.get("direction") or "long").strip().lower() - ex_sym = g.get("exchange_symbol") or cfg["normalize_exchange_symbol"](symbol) - mon_ok = (row["monitor_status"] or "").strip().lower() == "active" - pos = cfg["get_position"](ex_sym, direction) - qty = float(pos.get("contracts") or 0) - if not mon_ok or qty <= 0: - _close_roll_group(conn, cfg, g) - - -def _cancel_roll_leg_order(cfg: dict, group: dict, leg: dict) -> None: - oid = (leg.get("exchange_order_id") or "").strip() - if not oid: - return - symbol = group.get("symbol") or "" - ex_sym = group.get("exchange_symbol") or cfg["normalize_exchange_symbol"](symbol) - cancel = cfg.get("cancel_limit_order") - if callable(cancel): - try: - cancel(ex_sym, oid) - except Exception: - pass - - -def _contract_size(cfg: dict, ex_sym: str) -> float: - get_cs = cfg.get("get_contract_size") - if callable(get_cs): - try: - return float(get_cs(ex_sym) or 1.0) - except Exception: - pass - return 1.0 - - -def _resolve_add_mode(leg: dict) -> str: - raw = (leg.get("add_mode") or "").strip().lower() - if raw in (MARKET_MODE, "market", "市价", "市价加仓"): - return MARKET_MODE - if "786" in raw or raw == "fib_786": - return "fib_786" - if "618" in raw or raw == "fib_618": - return "fib_618" - if raw in (BREAKOUT_MODE, "突破", "突破加仓"): - return BREAKOUT_MODE - if raw.startswith("fib"): - return raw.replace(".", "_").replace("0.", "0") - return raw or MARKET_MODE - - -def _check_pending_roll_legs(conn, cfg: dict) -> None: - rows = conn.execute( - """SELECT l.*, g.symbol, g.exchange_symbol, g.direction, g.initial_take_profit, - g.order_monitor_id, g.risk_percent, g.leg_count - FROM roll_legs l - INNER JOIN roll_groups g ON g.id = l.roll_group_id AND g.status='active' - WHERE l.status='pending'""" - ).fetchall() - for row in rows: - leg = _row_dict(row) - group = { - "id": leg["roll_group_id"], - "symbol": leg["symbol"], - "exchange_symbol": leg["exchange_symbol"], - "direction": leg["direction"], - "initial_take_profit": leg["initial_take_profit"], - "order_monitor_id": leg["order_monitor_id"], - "risk_percent": leg.get("risk_percent"), - "leg_count": leg.get("leg_count"), - } - _process_pending_roll_leg(conn, cfg, group, leg) - - -def _process_pending_roll_leg(conn, cfg: dict, group: dict, leg: dict) -> None: - symbol = group.get("symbol") or "" - direction = (group.get("direction") or "long").strip().lower() - ex_sym = group.get("exchange_symbol") or cfg["normalize_exchange_symbol"](symbol) - mark_fn = cfg.get("get_mark_price") or cfg.get("get_price") - mark = mark_fn(symbol) if callable(mark_fn) else None - if mark is None: - return - mark_f = float(mark) - prev_mark = leg.get("last_mark_price") - try: - prev_f = float(prev_mark) if prev_mark not in (None, "") else None - except (TypeError, ValueError): - prev_f = None - - mode = _resolve_add_mode(leg) - sl = float(leg.get("new_stop_loss") or 0) - fib_u, fib_l = leg.get("fib_upper"), leg.get("fib_lower") - bp = leg.get("breakthrough_price") - - if mode in FIB_MODES and fib_u is not None and fib_l is not None: - if roll_fib_invalidate(direction, mark_f, float(fib_u), float(fib_l)): - _invalidate_roll_leg(conn, cfg, group, leg, mark_f, reason="止盈侧突破") - return - elif mode == BREAKOUT_MODE and sl > 0: - if roll_breakout_invalidate(direction, mark_f, sl): - _invalidate_roll_leg(conn, cfg, group, leg, mark_f, reason="止损侧突破") - return - - triggered = False - if mode in FIB_MODES: - lp = leg.get("limit_price") - if lp is not None and roll_fib_trigger_crossed(direction, prev_f, mark_f, float(lp)): - triggered = True - elif mode == BREAKOUT_MODE and bp is not None: - if roll_breakout_trigger_crossed(direction, prev_f, mark_f, float(bp)): - triggered = True - - conn.execute( - "UPDATE roll_legs SET last_mark_price=? WHERE id=? AND status='pending'", - (mark_f, int(leg["id"])), - ) - - if triggered: - _execute_pending_roll_leg(conn, cfg, group, leg, ex_sym, direction, mark_f) - return - - -def _execute_pending_roll_leg( - conn, - cfg: dict, - group: dict, - leg: dict, - ex_sym: str, - direction: str, - mark: float, -) -> None: - leg_id = int(leg["id"]) - gid = int(group["roll_group_id"]) if "roll_group_id" in leg else int(group["id"]) - mon_id = group.get("order_monitor_id") - mon = None - if mon_id: - row = conn.execute("SELECT * FROM order_monitors WHERE id=?", (mon_id,)).fetchone() - mon = _row_dict(row) if row else None - if not mon or (mon.get("status") or "").strip().lower() != "active": - _invalidate_roll_leg(conn, cfg, group, leg, mark, reason="监控单已失效") - return - - pos = cfg["get_position"](ex_sym, direction) or {} - qty = float(pos.get("contracts") or 0) - entry = float(pos.get("entry_price") or mon.get("trigger_price") or 0) - if qty <= 0 or entry <= 0: - _invalidate_roll_leg(conn, cfg, group, leg, mark, reason="无持仓") - return - - filled = count_filled_roll_legs(conn, gid) - if filled >= max_roll_legs(direction): - _invalidate_roll_leg(conn, cfg, group, leg, mark, reason="滚仓次数已满") - return - - try: - risk_pct = float(mon.get("risk_percent") or group.get("risk_percent") or 2) - except (TypeError, ValueError): - risk_pct = 2.0 - conn_cap = cfg["get_db"]() - try: - capital = float(cfg["get_trading_capital_usdt"](conn_cap)) - finally: - conn_cap.close() - - cs = _contract_size(cfg, ex_sym) - sl = float(leg.get("new_stop_loss") or 0) - tp0 = float(group.get("initial_take_profit") or mon.get("take_profit") or 0) - mode = _resolve_add_mode(leg) - - q2_raw, err = solve_add_amount_for_total_risk( - direction, qty, entry, mark, sl, calc_risk_budget_usdt(capital, risk_pct), cs - ) - if err or q2_raw is None or float(q2_raw) <= 0: - _invalidate_roll_leg(conn, cfg, group, leg, mark, reason=err or "无法计算加仓张数") - return - - amount = cfg["amount_to_precision"](ex_sym, float(q2_raw)) - if amount is None or float(amount) <= 0: - _invalidate_roll_leg(conn, cfg, group, leg, mark, reason="加仓张数低于交易所最小精度") - return - - lev_fn = cfg.get("default_leverage") - if not callable(lev_fn): - lev_fn = lambda _s: 5 - leverage = int(lev_fn(group.get("symbol") or "")) - - try: - order = cfg["market_add"](ex_sym, direction, float(amount), leverage) - fill = float( - cfg.get("resolve_fill_price", lambda o, s, p: p)(order, ex_sym, mark) or mark - ) - except Exception as e: - fe = cfg.get("friendly_error") - msg = fe(e) if callable(fe) else str(e) - _notify_roll_fail(cfg, group, leg, mark, msg) - return - - oid = str(order.get("id") or "") if isinstance(order, dict) else "" - cfg["replace_tpsl"](ex_sym, direction, sl, tp0, mon) - conn.execute( - """UPDATE roll_legs SET status='filled', fill_price=?, amount=?, exchange_order_id=?, - new_stop_loss=? WHERE id=? AND status='pending'""", - (fill, float(amount), oid, sl, leg_id), - ) - conn.execute( - "UPDATE roll_groups SET leg_count=?, current_stop_loss=?, updated_at=? WHERE id=?", - (filled + 1, sl, _now(cfg), gid), - ) - conn.execute( - "UPDATE order_monitors SET stop_loss=? WHERE id=? AND status='active'", - (sl, mon["id"]), - ) - - notify = cfg.get("send_wechat") - if callable(notify): - sym = group.get("symbol") or "" - mode_lbl = leg.get("add_mode") or mode_label(mode) - fmt = cfg.get("format_price") - px_txt = fmt(sym, fill) if callable(fmt) else str(fill) - sl_txt = fmt(sym, sl) if callable(fmt) else str(sl) - acct = _wechat_account(cfg) - dir_txt = _wechat_dir(cfg, direction) - notify( - f"# ✅ {sym} 滚仓触价成交\n" - f"**账户:{acct}**\n" - f"- 方式:{mode_lbl}|{dir_txt}\n" - f"- 成交价:{px_txt}|张数:{amount}\n" - f"- 新止损:{sl_txt}(止盈仍为首仓)\n" - ) - - -def _invalidate_roll_leg( - conn, - cfg: dict, - group: dict, - leg: dict, - mark: float, - *, - reason: str = "", -) -> None: - leg_id = int(leg["id"]) - cur = conn.execute("SELECT status FROM roll_legs WHERE id=?", (leg_id,)).fetchone() - if not cur or (cur[0] or "").strip().lower() in ("invalidated", "filled", "cancelled"): - return - _cancel_roll_leg_order(cfg, group, leg) - conn.execute( - "UPDATE roll_legs SET status='invalidated' WHERE id=? AND status='pending'", - (leg_id,), - ) - _send_roll_invalidate_wechat(cfg, group, leg, mark, reason=reason) - - -def _notify_roll_fail(cfg: dict, group: dict, leg: dict, mark: float, reason: str) -> None: - notify = cfg.get("send_wechat") - if not callable(notify): - return - sym = group.get("symbol") or "" - mode = leg.get("add_mode") or "滚仓" - acct = _wechat_account(cfg) - notify( - f"# ❌ {sym} 滚仓触价成交失败\n" - f"**账户:{acct}**\n" - f"- 方式:{mode}\n" - f"- 原因:{reason}\n" - ) - - -def _send_roll_invalidate_wechat( - cfg: dict, group: dict, leg: dict, mark: float, *, reason: str = "" -) -> None: - notify = cfg.get("send_wechat") - if not callable(notify): - return - sym = group.get("symbol") or "" - direction = (group.get("direction") or "long").strip().lower() - mode = leg.get("add_mode") or "滚仓监控" - fmt = cfg.get("format_price") - mark_txt = fmt(sym, mark) if callable(fmt) else str(mark) - acct = _wechat_account(cfg) - dir_txt = _wechat_dir(cfg, direction) - detail = reason or "条件不满足" - notify( - f"# ⚠️ {sym} 滚仓监控失效\n" - f"**账户:{acct}**\n" - f"- 方式:{mode}|{dir_txt}\n" - f"- 标记价 {mark_txt}|{detail}\n" - f"- 本条监控已结案,可重新提交\n" - ) - - -def _wechat_account(cfg: dict) -> str: - fn = cfg.get("wechat_account_label") - if callable(fn): - try: - return str(fn()) - except Exception: - pass - return str(cfg.get("exchange_display") or "") - - -def _wechat_dir(cfg: dict, direction: str) -> str: - fn = cfg.get("wechat_direction_text") - if callable(fn): - try: - return str(fn(direction)) - except Exception: - pass - return "做多" if (direction or "long").strip().lower() == "long" else "做空" +"""滚仓程序监控:斐波/突破触价市价成交、失效、外部平仓同步(各所共用)。""" +from __future__ import annotations + +from typing import Any, Optional + +from lib.strategy.strategy_roll_lib import ( + BREAKOUT_MODE, + FIB_MODES, + MARKET_MODE, + mode_label, + roll_breakout_invalidate, + roll_breakout_trigger_crossed, + roll_fib_invalidate, + roll_fib_trigger_crossed, + calc_risk_budget_usdt, + max_roll_legs, + preview_roll, + solve_add_amount_for_total_risk, +) +from lib.strategy.strategy_db import init_strategy_tables + +ROLL_LEG_STATUS_LABELS = { + "pending": "监控中", + "filled": "已成交", + "cancelled": "已删除", + "invalidated": "已失效", +} + + +def roll_leg_status_label(status: Optional[str]) -> str: + s = (status or "").strip().lower() + return ROLL_LEG_STATUS_LABELS.get(s, status or "—") + + +def check_roll_monitors(cfg: dict[str, Any]) -> None: + get_db = cfg["get_db"] + conn = get_db() + try: + init_strategy_tables(conn) + _reconcile_roll_groups(conn, cfg) + _check_pending_roll_legs(conn, cfg) + conn.commit() + except Exception: + try: + conn.rollback() + except Exception: + pass + finally: + try: + conn.close() + except Exception: + pass + + +def sync_roll_after_external_close( + cfg: dict, conn, symbol: str, direction: str, *, reason: str = "持仓已平" +) -> dict[str, Any]: + """中控/实例手动平仓后:取消 pending 腿并关闭 active 滚仓组(保留 filled 历史)。""" + norm = cfg.get("normalize_symbol_input") + sym = norm(symbol) if callable(norm) else (symbol or "").strip() + if not sym: + return {"ok": False, "msg": "symbol 无效", "closed_groups": 0, "cancelled_legs": 0} + direction = (direction or "long").strip().lower() + init_strategy_tables(conn) + rows = conn.execute( + """SELECT g.* FROM roll_groups g + WHERE g.status='active' AND g.symbol=? AND g.direction=?""", + (sym, direction), + ).fetchall() + closed = cancelled = 0 + for row in rows: + g = _row_dict(row) + cancelled += _cancel_pending_legs_for_group(conn, cfg, g, status="cancelled") + cur = conn.execute( + "UPDATE roll_groups SET status='closed', updated_at=? WHERE id=? AND status='active'", + (_now(cfg), int(g["id"])), + ) + if getattr(cur, "rowcount", 0): + closed += 1 + try: + from lib.strategy.strategy_wechat_notify import notify_roll_group_ended + + notify_roll_group_ended( + cfg, + group_id=int(g["id"]), + symbol=sym, + direction=direction, + reason=reason, + leg_count=int(g.get("leg_count") or 0), + ) + except Exception: + pass + try: + from lib.strategy.strategy_snapshot_lib import save_roll_group_snapshot + + save_roll_group_snapshot(cfg, conn, g, result_label="结束") + except Exception: + pass + return { + "ok": True, + "symbol": sym, + "direction": direction, + "closed_groups": closed, + "cancelled_legs": cancelled, + } + + +def cancel_roll_pending_leg(cfg: dict, conn, leg_id: int) -> tuple[bool, str]: + """用户删除 pending 滚仓腿(不可修改,仅删除)。""" + init_strategy_tables(conn) + row = conn.execute( + "SELECT l.*, g.symbol, g.direction, g.status AS group_status FROM roll_legs l " + "INNER JOIN roll_groups g ON g.id = l.roll_group_id WHERE l.id=?", + (int(leg_id),), + ).fetchone() + if not row: + return False, "滚仓腿不存在" + leg = _row_dict(row) + if (leg.get("status") or "").strip().lower() != "pending": + return False, "仅监控中的腿可删除" + _cancel_roll_leg_order(cfg, {"symbol": leg.get("symbol"), "exchange_symbol": leg.get("exchange_symbol")}, leg) + conn.execute( + "UPDATE roll_legs SET status='cancelled' WHERE id=? AND status='pending'", + (int(leg_id),), + ) + conn.commit() + return True, "已删除滚仓监控" + + +def count_filled_roll_legs(conn, roll_group_id: int) -> int: + row = conn.execute( + "SELECT COUNT(*) FROM roll_legs WHERE roll_group_id=? AND status='filled'", + (int(roll_group_id),), + ).fetchone() + return int(row[0] if row else 0) + + +def count_pending_roll_legs(conn, roll_group_id: int) -> int: + row = conn.execute( + "SELECT COUNT(*) FROM roll_legs WHERE roll_group_id=? AND status='pending'", + (int(roll_group_id),), + ).fetchone() + return int(row[0] if row else 0) + + +def _row_dict(row) -> dict: + if row is None: + return {} + try: + return dict(row) + except Exception: + return {} + + +def _now(cfg: dict) -> str: + fn = cfg.get("app_now_str") + return fn() if callable(fn) else "" + + +def _cancel_pending_legs_for_group(conn, cfg: dict, group: dict, *, status: str = "cancelled") -> int: + gid = int(group["id"]) + n = 0 + for leg in conn.execute( + "SELECT * FROM roll_legs WHERE roll_group_id=? AND status='pending'", + (gid,), + ).fetchall(): + ld = _row_dict(leg) + _cancel_roll_leg_order(cfg, group, ld) + conn.execute( + "UPDATE roll_legs SET status=? WHERE id=? AND status='pending'", + (status, ld["id"]), + ) + n += 1 + return n + + +def _close_roll_group(conn, cfg: dict, group: dict, *, reason: str = "下单监控已结案或交易所无同向持仓") -> None: + gid = int(group["id"]) + _cancel_pending_legs_for_group(conn, cfg, group, status="cancelled") + cur = conn.execute( + "UPDATE roll_groups SET status='closed', updated_at=? WHERE id=? AND status='active'", + (_now(cfg), gid), + ) + if getattr(cur, "rowcount", 0): + try: + from lib.strategy.strategy_wechat_notify import notify_roll_group_ended + + notify_roll_group_ended( + cfg, + group_id=gid, + symbol=group.get("symbol") or "", + direction=group.get("direction") or "long", + reason=reason, + leg_count=int(group.get("leg_count") or 0), + ) + except Exception: + pass + try: + from lib.strategy.strategy_snapshot_lib import save_roll_group_snapshot + + save_roll_group_snapshot(cfg, conn, group, result_label="结束") + except Exception: + pass + + +def _reconcile_roll_groups(conn, cfg: dict) -> None: + rows = conn.execute( + """SELECT g.*, m.status AS monitor_status + FROM roll_groups g + LEFT JOIN order_monitors m ON m.id = g.order_monitor_id + WHERE g.status='active'""" + ).fetchall() + for row in rows: + g = _row_dict(row) + symbol = g.get("symbol") or "" + direction = (g.get("direction") or "long").strip().lower() + ex_sym = g.get("exchange_symbol") or cfg["normalize_exchange_symbol"](symbol) + mon_ok = (row["monitor_status"] or "").strip().lower() == "active" + pos = cfg["get_position"](ex_sym, direction) + qty = float(pos.get("contracts") or 0) + if not mon_ok or qty <= 0: + _close_roll_group(conn, cfg, g) + + +def _cancel_roll_leg_order(cfg: dict, group: dict, leg: dict) -> None: + oid = (leg.get("exchange_order_id") or "").strip() + if not oid: + return + symbol = group.get("symbol") or "" + ex_sym = group.get("exchange_symbol") or cfg["normalize_exchange_symbol"](symbol) + cancel = cfg.get("cancel_limit_order") + if callable(cancel): + try: + cancel(ex_sym, oid) + except Exception: + pass + + +def _contract_size(cfg: dict, ex_sym: str) -> float: + get_cs = cfg.get("get_contract_size") + if callable(get_cs): + try: + return float(get_cs(ex_sym) or 1.0) + except Exception: + pass + return 1.0 + + +def _resolve_add_mode(leg: dict) -> str: + raw = (leg.get("add_mode") or "").strip().lower() + if raw in (MARKET_MODE, "market", "市价", "市价加仓"): + return MARKET_MODE + if "786" in raw or raw == "fib_786": + return "fib_786" + if "618" in raw or raw == "fib_618": + return "fib_618" + if raw in (BREAKOUT_MODE, "突破", "突破加仓"): + return BREAKOUT_MODE + if raw.startswith("fib"): + return raw.replace(".", "_").replace("0.", "0") + return raw or MARKET_MODE + + +def _check_pending_roll_legs(conn, cfg: dict) -> None: + rows = conn.execute( + """SELECT l.*, g.symbol, g.exchange_symbol, g.direction, g.initial_take_profit, + g.order_monitor_id, g.risk_percent, g.leg_count + FROM roll_legs l + INNER JOIN roll_groups g ON g.id = l.roll_group_id AND g.status='active' + WHERE l.status='pending'""" + ).fetchall() + for row in rows: + leg = _row_dict(row) + group = { + "id": leg["roll_group_id"], + "symbol": leg["symbol"], + "exchange_symbol": leg["exchange_symbol"], + "direction": leg["direction"], + "initial_take_profit": leg["initial_take_profit"], + "order_monitor_id": leg["order_monitor_id"], + "risk_percent": leg.get("risk_percent"), + "leg_count": leg.get("leg_count"), + } + _process_pending_roll_leg(conn, cfg, group, leg) + + +def _process_pending_roll_leg(conn, cfg: dict, group: dict, leg: dict) -> None: + symbol = group.get("symbol") or "" + direction = (group.get("direction") or "long").strip().lower() + ex_sym = group.get("exchange_symbol") or cfg["normalize_exchange_symbol"](symbol) + mark_fn = cfg.get("get_mark_price") or cfg.get("get_price") + mark = mark_fn(symbol) if callable(mark_fn) else None + if mark is None: + return + mark_f = float(mark) + prev_mark = leg.get("last_mark_price") + try: + prev_f = float(prev_mark) if prev_mark not in (None, "") else None + except (TypeError, ValueError): + prev_f = None + + mode = _resolve_add_mode(leg) + sl = float(leg.get("new_stop_loss") or 0) + fib_u, fib_l = leg.get("fib_upper"), leg.get("fib_lower") + bp = leg.get("breakthrough_price") + + if mode in FIB_MODES and fib_u is not None and fib_l is not None: + if roll_fib_invalidate(direction, mark_f, float(fib_u), float(fib_l)): + _invalidate_roll_leg(conn, cfg, group, leg, mark_f, reason="止盈侧突破") + return + elif mode == BREAKOUT_MODE and sl > 0: + if roll_breakout_invalidate(direction, mark_f, sl): + _invalidate_roll_leg(conn, cfg, group, leg, mark_f, reason="止损侧突破") + return + + triggered = False + if mode in FIB_MODES: + lp = leg.get("limit_price") + if lp is not None and roll_fib_trigger_crossed(direction, prev_f, mark_f, float(lp)): + triggered = True + elif mode == BREAKOUT_MODE and bp is not None: + if roll_breakout_trigger_crossed(direction, prev_f, mark_f, float(bp)): + triggered = True + + conn.execute( + "UPDATE roll_legs SET last_mark_price=? WHERE id=? AND status='pending'", + (mark_f, int(leg["id"])), + ) + + if triggered: + _execute_pending_roll_leg(conn, cfg, group, leg, ex_sym, direction, mark_f) + return + + +def _execute_pending_roll_leg( + conn, + cfg: dict, + group: dict, + leg: dict, + ex_sym: str, + direction: str, + mark: float, +) -> None: + leg_id = int(leg["id"]) + gid = int(group["roll_group_id"]) if "roll_group_id" in leg else int(group["id"]) + mon_id = group.get("order_monitor_id") + mon = None + if mon_id: + row = conn.execute("SELECT * FROM order_monitors WHERE id=?", (mon_id,)).fetchone() + mon = _row_dict(row) if row else None + if not mon or (mon.get("status") or "").strip().lower() != "active": + _invalidate_roll_leg(conn, cfg, group, leg, mark, reason="监控单已失效") + return + + pos = cfg["get_position"](ex_sym, direction) or {} + qty = float(pos.get("contracts") or 0) + entry = float(pos.get("entry_price") or mon.get("trigger_price") or 0) + if qty <= 0 or entry <= 0: + _invalidate_roll_leg(conn, cfg, group, leg, mark, reason="无持仓") + return + + filled = count_filled_roll_legs(conn, gid) + if filled >= max_roll_legs(direction): + _invalidate_roll_leg(conn, cfg, group, leg, mark, reason="滚仓次数已满") + return + + try: + risk_pct = float(mon.get("risk_percent") or group.get("risk_percent") or 2) + except (TypeError, ValueError): + risk_pct = 2.0 + conn_cap = cfg["get_db"]() + try: + capital = float(cfg["get_trading_capital_usdt"](conn_cap)) + finally: + conn_cap.close() + + cs = _contract_size(cfg, ex_sym) + sl = float(leg.get("new_stop_loss") or 0) + tp0 = float(group.get("initial_take_profit") or mon.get("take_profit") or 0) + mode = _resolve_add_mode(leg) + + q2_raw, err = solve_add_amount_for_total_risk( + direction, qty, entry, mark, sl, calc_risk_budget_usdt(capital, risk_pct), cs + ) + if err or q2_raw is None or float(q2_raw) <= 0: + _invalidate_roll_leg(conn, cfg, group, leg, mark, reason=err or "无法计算加仓张数") + return + + amount = cfg["amount_to_precision"](ex_sym, float(q2_raw)) + if amount is None or float(amount) <= 0: + _invalidate_roll_leg(conn, cfg, group, leg, mark, reason="加仓张数低于交易所最小精度") + return + + lev_fn = cfg.get("default_leverage") + if not callable(lev_fn): + lev_fn = lambda _s: 5 + leverage = int(lev_fn(group.get("symbol") or "")) + + try: + order = cfg["market_add"](ex_sym, direction, float(amount), leverage) + fill = float( + cfg.get("resolve_fill_price", lambda o, s, p: p)(order, ex_sym, mark) or mark + ) + except Exception as e: + fe = cfg.get("friendly_error") + msg = fe(e) if callable(fe) else str(e) + _notify_roll_fail(cfg, group, leg, mark, msg) + return + + oid = str(order.get("id") or "") if isinstance(order, dict) else "" + cfg["replace_tpsl"](ex_sym, direction, sl, tp0, mon) + conn.execute( + """UPDATE roll_legs SET status='filled', fill_price=?, amount=?, exchange_order_id=?, + new_stop_loss=? WHERE id=? AND status='pending'""", + (fill, float(amount), oid, sl, leg_id), + ) + conn.execute( + "UPDATE roll_groups SET leg_count=?, current_stop_loss=?, updated_at=? WHERE id=?", + (filled + 1, sl, _now(cfg), gid), + ) + conn.execute( + "UPDATE order_monitors SET stop_loss=? WHERE id=? AND status='active'", + (sl, mon["id"]), + ) + + notify = cfg.get("send_wechat") + if callable(notify): + sym = group.get("symbol") or "" + mode_lbl = leg.get("add_mode") or mode_label(mode) + fmt = cfg.get("format_price") + px_txt = fmt(sym, fill) if callable(fmt) else str(fill) + sl_txt = fmt(sym, sl) if callable(fmt) else str(sl) + acct = _wechat_account(cfg) + dir_txt = _wechat_dir(cfg, direction) + notify( + f"# ✅ {sym} 滚仓触价成交\n" + f"**账户:{acct}**\n" + f"- 方式:{mode_lbl}|{dir_txt}\n" + f"- 成交价:{px_txt}|张数:{amount}\n" + f"- 新止损:{sl_txt}(止盈仍为首仓)\n" + ) + + +def _invalidate_roll_leg( + conn, + cfg: dict, + group: dict, + leg: dict, + mark: float, + *, + reason: str = "", +) -> None: + leg_id = int(leg["id"]) + cur = conn.execute("SELECT status FROM roll_legs WHERE id=?", (leg_id,)).fetchone() + if not cur or (cur[0] or "").strip().lower() in ("invalidated", "filled", "cancelled"): + return + _cancel_roll_leg_order(cfg, group, leg) + conn.execute( + "UPDATE roll_legs SET status='invalidated' WHERE id=? AND status='pending'", + (leg_id,), + ) + _send_roll_invalidate_wechat(cfg, group, leg, mark, reason=reason) + + +def _notify_roll_fail(cfg: dict, group: dict, leg: dict, mark: float, reason: str) -> None: + notify = cfg.get("send_wechat") + if not callable(notify): + return + sym = group.get("symbol") or "" + mode = leg.get("add_mode") or "滚仓" + acct = _wechat_account(cfg) + notify( + f"# ❌ {sym} 滚仓触价成交失败\n" + f"**账户:{acct}**\n" + f"- 方式:{mode}\n" + f"- 原因:{reason}\n" + ) + + +def _send_roll_invalidate_wechat( + cfg: dict, group: dict, leg: dict, mark: float, *, reason: str = "" +) -> None: + notify = cfg.get("send_wechat") + if not callable(notify): + return + sym = group.get("symbol") or "" + direction = (group.get("direction") or "long").strip().lower() + mode = leg.get("add_mode") or "滚仓监控" + fmt = cfg.get("format_price") + mark_txt = fmt(sym, mark) if callable(fmt) else str(mark) + acct = _wechat_account(cfg) + dir_txt = _wechat_dir(cfg, direction) + detail = reason or "条件不满足" + notify( + f"# ⚠️ {sym} 滚仓监控失效\n" + f"**账户:{acct}**\n" + f"- 方式:{mode}|{dir_txt}\n" + f"- 标记价 {mark_txt}|{detail}\n" + f"- 本条监控已结案,可重新提交\n" + ) + + +def _wechat_account(cfg: dict) -> str: + fn = cfg.get("wechat_account_label") + if callable(fn): + try: + return str(fn()) + except Exception: + pass + return str(cfg.get("exchange_display") or "") + + +def _wechat_dir(cfg: dict, direction: str) -> str: + fn = cfg.get("wechat_direction_text") + if callable(fn): + try: + return str(fn(direction)) + except Exception: + pass + return "做多" if (direction or "long").strip().lower() == "long" else "做空" diff --git a/strategy_roll_ui_lib.py b/lib/strategy/strategy_roll_ui_lib.py similarity index 100% rename from strategy_roll_ui_lib.py rename to lib/strategy/strategy_roll_ui_lib.py diff --git a/strategy_snapshot_lib.py b/lib/strategy/strategy_snapshot_lib.py similarity index 96% rename from strategy_snapshot_lib.py rename to lib/strategy/strategy_snapshot_lib.py index caeb141..a55c4e6 100644 --- a/strategy_snapshot_lib.py +++ b/lib/strategy/strategy_snapshot_lib.py @@ -1,529 +1,529 @@ -"""策略结束快照:趋势回调 / 顺势加仓(四所共用)。""" -from __future__ import annotations - -import json -from datetime import datetime, timezone -from typing import Any, Callable, Optional - -STRATEGY_TREND = "trend_pullback" -STRATEGY_ROLL = "roll" -STRATEGY_SNAPSHOTS_MAX_ROWS = 100 -# 同一趋势计划只允许一条「结束类」快照(中控全平 + 监控止损 + 实例结束计划) -FINAL_TREND_CLOSE_RANK = { - "手动平仓": 3, - "止盈": 2, - "止损": 1, -} -FINAL_TREND_CLOSE_LABELS = tuple(FINAL_TREND_CLOSE_RANK.keys()) - -STRATEGY_SNAPSHOTS_SQL = """ -CREATE TABLE IF NOT EXISTS strategy_trade_snapshots ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - strategy_type TEXT NOT NULL, - source_id INTEGER, - symbol TEXT, - exchange_symbol TEXT, - direction TEXT, - result_label TEXT, - status_at_close TEXT, - opened_at TEXT, - closed_at TEXT, - pnl_amount REAL, - snapshot_json TEXT NOT NULL, - created_at TEXT -) -""" - - -def init_strategy_snapshot_table(conn) -> None: - conn.execute(STRATEGY_SNAPSHOTS_SQL) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_strategy_snapshots_closed " - "ON strategy_trade_snapshots(closed_at DESC)" - ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_strategy_snapshots_type " - "ON strategy_trade_snapshots(strategy_type, source_id)" - ) - - -def _row_dict(row) -> dict: - if row is None: - return {} - try: - return dict(row) - except Exception: - return {} - - -def _json_dumps(obj: Any) -> str: - return json.dumps(obj, ensure_ascii=False, separators=(",", ":")) - - -def build_trend_dca_levels(plan: dict) -> list[dict]: - """首仓 + 补仓档位列表(供策略页 / 中控)。""" - out: list[dict] = [] - p = plan or {} - try: - legs_done = int(p.get("legs_done") or 0) - except (TypeError, ValueError): - legs_done = 0 - try: - dca_legs = int(p.get("dca_legs") or 0) - except (TypeError, ValueError): - dca_legs = 0 - first_done = int(p.get("first_order_done") or 0) != 0 - try: - grid = json.loads(p.get("grid_prices_json") or "[]") - if not isinstance(grid, list): - grid = [] - except Exception: - grid = [] - try: - leg_amounts = json.loads(p.get("leg_amounts_json") or "[]") - if not isinstance(leg_amounts, list): - leg_amounts = [] - except Exception: - leg_amounts = [] - - out.append( - { - "i": 0, - "leg_key": "first", - "label": "首仓", - "price": None, - "contracts": p.get("first_order_amount"), - "status": "done" if first_done else "pending", - "status_label": "已开仓" if first_done else "待开仓", - } - ) - n = max(len(grid), len(leg_amounts), dca_legs) - for idx in range(n): - leg_i = idx + 1 - price = grid[idx] if idx < len(grid) else None - contracts = leg_amounts[idx] if idx < len(leg_amounts) else None - done = leg_i <= legs_done - out.append( - { - "i": leg_i, - "leg_key": f"dca_{leg_i}", - "label": f"补仓{leg_i}", - "price": price, - "contracts": contracts, - "status": "done" if done else "pending", - "status_label": "已补仓" if done else "待补仓", - } - ) - return out - - -def attach_trend_dca_levels(plan: dict) -> dict: - from strategy_trend_lib import enrich_trend_dca_levels_with_tp - - d = dict(plan or {}) - levels = build_trend_dca_levels(d) - d["dca_levels"] = enrich_trend_dca_levels_with_tp(d, levels) - return d - - -def _snapshot_key_exists( - conn, strategy_type: str, source_id: int, result_label: str -) -> bool: - if source_id <= 0: - return False - label = (result_label or "").strip() - row = conn.execute( - """SELECT 1 FROM strategy_trade_snapshots - WHERE strategy_type=? AND source_id=? AND result_label=? - LIMIT 1""", - (strategy_type, int(source_id), label), - ).fetchone() - return row is not None - - -def _final_trend_close_rank(result_label: str) -> int: - return int(FINAL_TREND_CLOSE_RANK.get((result_label or "").strip(), 0)) - - -def _purge_weaker_trend_final_snapshots( - conn, plan_id: int, result_label: str -) -> None: - """写入更高优先级结束快照时,删除同计划较弱的结束记录。""" - rank = _final_trend_close_rank(result_label) - if rank <= 0 or plan_id <= 0: - return - for label, lr in FINAL_TREND_CLOSE_RANK.items(): - if lr < rank: - conn.execute( - """DELETE FROM strategy_trade_snapshots - WHERE strategy_type=? AND source_id=? AND result_label=?""", - (STRATEGY_TREND, int(plan_id), label), - ) - - -def dedupe_strategy_snapshots(conn) -> int: - """删除重复快照:同结果去重 + 同计划仅保留最高优先级结束类记录。""" - init_strategy_snapshot_table(conn) - removed = 0 - cur = conn.execute( - """DELETE FROM strategy_trade_snapshots - WHERE id IN ( - SELECT s1.id FROM strategy_trade_snapshots s1 - INNER JOIN strategy_trade_snapshots s2 - ON s1.strategy_type = s2.strategy_type - AND s1.source_id = s2.source_id - AND s1.result_label = s2.result_label - AND s1.id < s2.id - )""" - ) - removed += int(getattr(cur, "rowcount", 0) or 0) - rows = conn.execute( - f"""SELECT id, source_id, result_label FROM strategy_trade_snapshots - WHERE strategy_type=? AND result_label IN ({",".join("?" * len(FINAL_TREND_CLOSE_LABELS))})""", - (STRATEGY_TREND, *FINAL_TREND_CLOSE_LABELS), - ).fetchall() - by_plan: dict[int, list] = {} - for row in rows: - d = _row_dict(row) - try: - pid = int(d.get("source_id") or 0) - except (TypeError, ValueError): - pid = 0 - if pid <= 0: - continue - by_plan.setdefault(pid, []).append(d) - drop_ids: list[int] = [] - for snaps in by_plan.values(): - if len(snaps) <= 1: - continue - best = max( - snaps, - key=lambda s: ( - _final_trend_close_rank(str(s.get("result_label") or "")), - int(s.get("id") or 0), - ), - ) - keep_id = int(best.get("id") or 0) - for s in snaps: - sid = int(s.get("id") or 0) - if sid and sid != keep_id: - drop_ids.append(sid) - if drop_ids: - placeholders = ",".join("?" * len(drop_ids)) - cur2 = conn.execute( - f"DELETE FROM strategy_trade_snapshots WHERE id IN ({placeholders})", - drop_ids, - ) - removed += int(getattr(cur2, "rowcount", 0) or 0) - return removed - - -def save_trend_plan_snapshot( - cfg: dict, - conn, - plan_row: Any, - *, - result_label: str, - exit_price: float | None = None, - pnl_amount: float | None = None, - closed_at: str | None = None, -) -> None: - init_strategy_snapshot_table(conn) - row = _row_dict(plan_row) - plan_id = int(row.get("id") or 0) - if plan_id <= 0: - return - label = (result_label or "").strip() - close_rank = _final_trend_close_rank(label) - if close_rank > 0: - existing = conn.execute( - f"""SELECT result_label FROM strategy_trade_snapshots - WHERE strategy_type=? AND source_id=? AND result_label IN ({",".join("?" * len(FINAL_TREND_CLOSE_LABELS))})""", - (STRATEGY_TREND, plan_id, *FINAL_TREND_CLOSE_LABELS), - ).fetchall() - for ex in existing: - ex_label = str(_row_dict(ex).get("result_label") or "") - if _final_trend_close_rank(ex_label) >= close_rank: - return - _purge_weaker_trend_final_snapshots(conn, plan_id, label) - elif _snapshot_key_exists(conn, STRATEGY_TREND, plan_id, label): - return - m = cfg.get("app_module") - close_ts = (closed_at or "").strip() or ( - m.app_now_str() - if m is not None and hasattr(m, "app_now_str") - else datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") - ) - payload = attach_trend_dca_levels(row) - payload["result_label"] = result_label - payload["exit_price"] = exit_price - payload["pnl_amount"] = pnl_amount - payload["status_at_close"] = row.get("status") - conn.execute( - """INSERT INTO strategy_trade_snapshots ( - strategy_type, source_id, symbol, exchange_symbol, direction, - result_label, status_at_close, opened_at, closed_at, pnl_amount, snapshot_json, created_at - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - STRATEGY_TREND, - plan_id, - row.get("symbol"), - row.get("exchange_symbol"), - row.get("direction"), - result_label, - row.get("status"), - row.get("opened_at"), - close_ts, - pnl_amount, - _json_dumps(payload), - close_ts, - ), - ) - prune_strategy_snapshots(conn, keep=STRATEGY_SNAPSHOTS_MAX_ROWS) - - -def save_roll_group_snapshot( - cfg: dict, - conn, - group: dict, - *, - result_label: str = "结束", - pnl_amount: float | None = None, -) -> None: - init_strategy_snapshot_table(conn) - g = dict(group or {}) - gid = int(g.get("id") or 0) - if gid <= 0: - return - label = (result_label or "结束").strip() - if _snapshot_key_exists(conn, STRATEGY_ROLL, gid, label): - return - legs = [] - for leg in conn.execute( - "SELECT * FROM roll_legs WHERE roll_group_id=? ORDER BY leg_index ASC, id ASC", - (gid,), - ).fetchall(): - ld = _row_dict(leg) - try: - from strategy_roll_monitor_lib import roll_leg_status_label - - ld["status_label"] = roll_leg_status_label(ld.get("status")) - except Exception: - ld["status_label"] = ld.get("status") or "" - legs.append(ld) - m = cfg.get("app_module") - closed_at = ( - m.app_now_str() - if m is not None and hasattr(m, "app_now_str") - else datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") - ) - payload = { - "group": g, - "legs": legs, - "result_label": result_label, - "pnl_amount": pnl_amount, - } - conn.execute( - """INSERT INTO strategy_trade_snapshots ( - strategy_type, source_id, symbol, exchange_symbol, direction, - result_label, status_at_close, opened_at, closed_at, pnl_amount, snapshot_json, created_at - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - STRATEGY_ROLL, - gid, - g.get("symbol"), - g.get("exchange_symbol"), - g.get("direction"), - result_label, - g.get("status"), - g.get("created_at"), - closed_at, - pnl_amount, - _json_dumps(payload), - closed_at, - ), - ) - prune_strategy_snapshots(conn, keep=STRATEGY_SNAPSHOTS_MAX_ROWS) - - -def prune_strategy_snapshots(conn, *, keep: int = STRATEGY_SNAPSHOTS_MAX_ROWS) -> None: - """仅保留最近 keep 条策略快照(按 closed_at / id 倒序)。""" - dedupe_strategy_snapshots(conn) - k = max(1, min(int(keep), 500)) - conn.execute( - """DELETE FROM strategy_trade_snapshots - WHERE id NOT IN ( - SELECT id FROM strategy_trade_snapshots - ORDER BY COALESCE(closed_at, created_at, '') DESC, id DESC - LIMIT ? - )""", - (k,), - ) - - -def _snapshot_pnl(row: dict, snap: dict) -> float | None: - for key in ("pnl_amount",): - v = row.get(key) - if v is not None and v != "": - try: - return float(v) - except (TypeError, ValueError): - pass - v = snap.get("pnl_amount") - if v is not None and v != "": - try: - return float(v) - except (TypeError, ValueError): - pass - return None - - -def _trend_dca_stats(snap: dict) -> dict: - levels = snap.get("dca_levels") or build_trend_dca_levels(snap) - dca_only = [ - lv - for lv in levels - if (lv.get("leg_key") or "") != "first" and (lv.get("label") or "") != "首仓" - ] - done = sum(1 for lv in dca_only if lv.get("status") == "done") - total = len(dca_only) - pending = total - done - if total <= 0: - tag = "na" - elif done <= 0: - tag = "no_dca" - elif done >= total: - tag = "dca_done" - else: - tag = "dca_partial" - return { - "dca_done": done, - "dca_total": total, - "dca_pending": pending, - "dca_tag": tag, - } - - -def _roll_leg_stats(snap: dict) -> dict: - legs = snap.get("legs") or [] - if not isinstance(legs, list): - legs = [] - filled = sum(1 for lg in legs if (lg.get("status") or "").lower() == "filled") - total = len(legs) - pending = total - filled - if total <= 0: - tag = "na" - elif filled <= 0: - tag = "no_dca" - elif filled >= total: - tag = "dca_done" - else: - tag = "dca_partial" - return { - "dca_done": filled, - "dca_total": total, - "dca_pending": pending, - "dca_tag": tag, - } - - -def enrich_strategy_snapshot_row(row: dict) -> dict: - d = dict(row or {}) - snap = d.get("snapshot") or {} - st = (d.get("strategy_type") or "").strip() - pnl = _snapshot_pnl(d, snap) - if pnl is not None: - if pnl > 1e-9: - d["filter_pnl"] = "profit" - elif pnl < -1e-9: - d["filter_pnl"] = "loss" - else: - d["filter_pnl"] = "flat" - else: - d["filter_pnl"] = "unknown" - snap_sym = "" - if isinstance(snap, dict): - snap_sym = (snap.get("symbol") or snap.get("exchange_symbol") or "").strip() - sym = (d.get("symbol") or d.get("exchange_symbol") or snap_sym or "").strip() - if sym: - d["symbol"] = d.get("symbol") or sym - d["exchange_symbol"] = d.get("exchange_symbol") or sym - d["filter_symbol"] = sym.upper().split("/")[0].split(":")[0] if sym else "" - closed = (d.get("closed_at") or d.get("created_at") or "").strip() - d["sort_ts"] = closed - if st == STRATEGY_TREND: - stats = _trend_dca_stats(snap) - d.update(stats) - legs_txt = ( - f"{stats['dca_done']}/{stats['dca_total']}" - if stats["dca_total"] > 0 - else "0/0" - ) - d["summary_dca"] = legs_txt - else: - stats = _roll_leg_stats(snap) - d.update(stats) - d["summary_dca"] = ( - f"{stats['dca_done']}/{stats['dca_total']}腿" - if stats["dca_total"] > 0 - else "—" - ) - return d - - -def list_strategy_snapshots(conn, *, limit: int = 200) -> list[dict]: - init_strategy_snapshot_table(conn) - rows = conn.execute( - "SELECT * FROM strategy_trade_snapshots ORDER BY id DESC LIMIT ?", - (max(1, min(int(limit), 500)),), - ).fetchall() - out = [] - seen: dict[tuple[str, int, str], int] = {} - for r in rows: - d = _row_dict(r) - try: - d["snapshot"] = json.loads(d.get("snapshot_json") or "{}") - except Exception: - d["snapshot"] = {} - st = (d.get("strategy_type") or "").strip() - d["strategy_label"] = "趋势回调" if st == STRATEGY_TREND else "顺势加仓" - enriched = enrich_strategy_snapshot_row(d) - try: - source_id = int(enriched.get("source_id") or 0) - except (TypeError, ValueError): - source_id = 0 - result_label = (enriched.get("result_label") or "").strip() - close_rank = _final_trend_close_rank(result_label) - if st == STRATEGY_TREND and source_id > 0 and close_rank > 0: - plan_key = (st, source_id) - snap_id = int(enriched.get("id") or 0) - prev = seen.get(plan_key) - if prev is not None: - prev_id, prev_rank = prev - if prev_rank > close_rank or (prev_rank == close_rank and prev_id >= snap_id): - continue - out = [x for x in out if int(x.get("id") or 0) != prev_id] - seen[plan_key] = (snap_id, close_rank) - out.append(enriched) - continue - key = (st, source_id, result_label) - snap_id = int(enriched.get("id") or 0) - prev = seen.get(key) - if prev is not None and prev[0] >= snap_id: - continue - if prev is not None: - out = [x for x in out if int(x.get("id") or 0) != prev[0]] - seen[key] = (snap_id, 0) - out.append(enriched) - return out - - -def list_strategy_snapshots_split( - conn, *, limit: int = STRATEGY_SNAPSHOTS_MAX_ROWS -) -> tuple[list[dict], list[dict], list[str]]: - """趋势 / 顺势分组,及筛选用币种列表。""" - all_rows = list_strategy_snapshots(conn, limit=limit) - trend = [r for r in all_rows if (r.get("strategy_type") or "") == STRATEGY_TREND] - roll = [r for r in all_rows if (r.get("strategy_type") or "") == STRATEGY_ROLL] - symbols = sorted({r.get("filter_symbol") or "" for r in all_rows if r.get("filter_symbol")}) - return trend, roll, symbols +"""策略结束快照:趋势回调 / 顺势加仓(四所共用)。""" +from __future__ import annotations + +import json +from datetime import datetime, timezone +from typing import Any, Callable, Optional + +STRATEGY_TREND = "trend_pullback" +STRATEGY_ROLL = "roll" +STRATEGY_SNAPSHOTS_MAX_ROWS = 100 +# 同一趋势计划只允许一条「结束类」快照(中控全平 + 监控止损 + 实例结束计划) +FINAL_TREND_CLOSE_RANK = { + "手动平仓": 3, + "止盈": 2, + "止损": 1, +} +FINAL_TREND_CLOSE_LABELS = tuple(FINAL_TREND_CLOSE_RANK.keys()) + +STRATEGY_SNAPSHOTS_SQL = """ +CREATE TABLE IF NOT EXISTS strategy_trade_snapshots ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + strategy_type TEXT NOT NULL, + source_id INTEGER, + symbol TEXT, + exchange_symbol TEXT, + direction TEXT, + result_label TEXT, + status_at_close TEXT, + opened_at TEXT, + closed_at TEXT, + pnl_amount REAL, + snapshot_json TEXT NOT NULL, + created_at TEXT +) +""" + + +def init_strategy_snapshot_table(conn) -> None: + conn.execute(STRATEGY_SNAPSHOTS_SQL) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_strategy_snapshots_closed " + "ON strategy_trade_snapshots(closed_at DESC)" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_strategy_snapshots_type " + "ON strategy_trade_snapshots(strategy_type, source_id)" + ) + + +def _row_dict(row) -> dict: + if row is None: + return {} + try: + return dict(row) + except Exception: + return {} + + +def _json_dumps(obj: Any) -> str: + return json.dumps(obj, ensure_ascii=False, separators=(",", ":")) + + +def build_trend_dca_levels(plan: dict) -> list[dict]: + """首仓 + 补仓档位列表(供策略页 / 中控)。""" + out: list[dict] = [] + p = plan or {} + try: + legs_done = int(p.get("legs_done") or 0) + except (TypeError, ValueError): + legs_done = 0 + try: + dca_legs = int(p.get("dca_legs") or 0) + except (TypeError, ValueError): + dca_legs = 0 + first_done = int(p.get("first_order_done") or 0) != 0 + try: + grid = json.loads(p.get("grid_prices_json") or "[]") + if not isinstance(grid, list): + grid = [] + except Exception: + grid = [] + try: + leg_amounts = json.loads(p.get("leg_amounts_json") or "[]") + if not isinstance(leg_amounts, list): + leg_amounts = [] + except Exception: + leg_amounts = [] + + out.append( + { + "i": 0, + "leg_key": "first", + "label": "首仓", + "price": None, + "contracts": p.get("first_order_amount"), + "status": "done" if first_done else "pending", + "status_label": "已开仓" if first_done else "待开仓", + } + ) + n = max(len(grid), len(leg_amounts), dca_legs) + for idx in range(n): + leg_i = idx + 1 + price = grid[idx] if idx < len(grid) else None + contracts = leg_amounts[idx] if idx < len(leg_amounts) else None + done = leg_i <= legs_done + out.append( + { + "i": leg_i, + "leg_key": f"dca_{leg_i}", + "label": f"补仓{leg_i}", + "price": price, + "contracts": contracts, + "status": "done" if done else "pending", + "status_label": "已补仓" if done else "待补仓", + } + ) + return out + + +def attach_trend_dca_levels(plan: dict) -> dict: + from lib.strategy.strategy_trend_lib import enrich_trend_dca_levels_with_tp + + d = dict(plan or {}) + levels = build_trend_dca_levels(d) + d["dca_levels"] = enrich_trend_dca_levels_with_tp(d, levels) + return d + + +def _snapshot_key_exists( + conn, strategy_type: str, source_id: int, result_label: str +) -> bool: + if source_id <= 0: + return False + label = (result_label or "").strip() + row = conn.execute( + """SELECT 1 FROM strategy_trade_snapshots + WHERE strategy_type=? AND source_id=? AND result_label=? + LIMIT 1""", + (strategy_type, int(source_id), label), + ).fetchone() + return row is not None + + +def _final_trend_close_rank(result_label: str) -> int: + return int(FINAL_TREND_CLOSE_RANK.get((result_label or "").strip(), 0)) + + +def _purge_weaker_trend_final_snapshots( + conn, plan_id: int, result_label: str +) -> None: + """写入更高优先级结束快照时,删除同计划较弱的结束记录。""" + rank = _final_trend_close_rank(result_label) + if rank <= 0 or plan_id <= 0: + return + for label, lr in FINAL_TREND_CLOSE_RANK.items(): + if lr < rank: + conn.execute( + """DELETE FROM strategy_trade_snapshots + WHERE strategy_type=? AND source_id=? AND result_label=?""", + (STRATEGY_TREND, int(plan_id), label), + ) + + +def dedupe_strategy_snapshots(conn) -> int: + """删除重复快照:同结果去重 + 同计划仅保留最高优先级结束类记录。""" + init_strategy_snapshot_table(conn) + removed = 0 + cur = conn.execute( + """DELETE FROM strategy_trade_snapshots + WHERE id IN ( + SELECT s1.id FROM strategy_trade_snapshots s1 + INNER JOIN strategy_trade_snapshots s2 + ON s1.strategy_type = s2.strategy_type + AND s1.source_id = s2.source_id + AND s1.result_label = s2.result_label + AND s1.id < s2.id + )""" + ) + removed += int(getattr(cur, "rowcount", 0) or 0) + rows = conn.execute( + f"""SELECT id, source_id, result_label FROM strategy_trade_snapshots + WHERE strategy_type=? AND result_label IN ({",".join("?" * len(FINAL_TREND_CLOSE_LABELS))})""", + (STRATEGY_TREND, *FINAL_TREND_CLOSE_LABELS), + ).fetchall() + by_plan: dict[int, list] = {} + for row in rows: + d = _row_dict(row) + try: + pid = int(d.get("source_id") or 0) + except (TypeError, ValueError): + pid = 0 + if pid <= 0: + continue + by_plan.setdefault(pid, []).append(d) + drop_ids: list[int] = [] + for snaps in by_plan.values(): + if len(snaps) <= 1: + continue + best = max( + snaps, + key=lambda s: ( + _final_trend_close_rank(str(s.get("result_label") or "")), + int(s.get("id") or 0), + ), + ) + keep_id = int(best.get("id") or 0) + for s in snaps: + sid = int(s.get("id") or 0) + if sid and sid != keep_id: + drop_ids.append(sid) + if drop_ids: + placeholders = ",".join("?" * len(drop_ids)) + cur2 = conn.execute( + f"DELETE FROM strategy_trade_snapshots WHERE id IN ({placeholders})", + drop_ids, + ) + removed += int(getattr(cur2, "rowcount", 0) or 0) + return removed + + +def save_trend_plan_snapshot( + cfg: dict, + conn, + plan_row: Any, + *, + result_label: str, + exit_price: float | None = None, + pnl_amount: float | None = None, + closed_at: str | None = None, +) -> None: + init_strategy_snapshot_table(conn) + row = _row_dict(plan_row) + plan_id = int(row.get("id") or 0) + if plan_id <= 0: + return + label = (result_label or "").strip() + close_rank = _final_trend_close_rank(label) + if close_rank > 0: + existing = conn.execute( + f"""SELECT result_label FROM strategy_trade_snapshots + WHERE strategy_type=? AND source_id=? AND result_label IN ({",".join("?" * len(FINAL_TREND_CLOSE_LABELS))})""", + (STRATEGY_TREND, plan_id, *FINAL_TREND_CLOSE_LABELS), + ).fetchall() + for ex in existing: + ex_label = str(_row_dict(ex).get("result_label") or "") + if _final_trend_close_rank(ex_label) >= close_rank: + return + _purge_weaker_trend_final_snapshots(conn, plan_id, label) + elif _snapshot_key_exists(conn, STRATEGY_TREND, plan_id, label): + return + m = cfg.get("app_module") + close_ts = (closed_at or "").strip() or ( + m.app_now_str() + if m is not None and hasattr(m, "app_now_str") + else datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + ) + payload = attach_trend_dca_levels(row) + payload["result_label"] = result_label + payload["exit_price"] = exit_price + payload["pnl_amount"] = pnl_amount + payload["status_at_close"] = row.get("status") + conn.execute( + """INSERT INTO strategy_trade_snapshots ( + strategy_type, source_id, symbol, exchange_symbol, direction, + result_label, status_at_close, opened_at, closed_at, pnl_amount, snapshot_json, created_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + STRATEGY_TREND, + plan_id, + row.get("symbol"), + row.get("exchange_symbol"), + row.get("direction"), + result_label, + row.get("status"), + row.get("opened_at"), + close_ts, + pnl_amount, + _json_dumps(payload), + close_ts, + ), + ) + prune_strategy_snapshots(conn, keep=STRATEGY_SNAPSHOTS_MAX_ROWS) + + +def save_roll_group_snapshot( + cfg: dict, + conn, + group: dict, + *, + result_label: str = "结束", + pnl_amount: float | None = None, +) -> None: + init_strategy_snapshot_table(conn) + g = dict(group or {}) + gid = int(g.get("id") or 0) + if gid <= 0: + return + label = (result_label or "结束").strip() + if _snapshot_key_exists(conn, STRATEGY_ROLL, gid, label): + return + legs = [] + for leg in conn.execute( + "SELECT * FROM roll_legs WHERE roll_group_id=? ORDER BY leg_index ASC, id ASC", + (gid,), + ).fetchall(): + ld = _row_dict(leg) + try: + from lib.strategy.strategy_roll_monitor_lib import roll_leg_status_label + + ld["status_label"] = roll_leg_status_label(ld.get("status")) + except Exception: + ld["status_label"] = ld.get("status") or "" + legs.append(ld) + m = cfg.get("app_module") + closed_at = ( + m.app_now_str() + if m is not None and hasattr(m, "app_now_str") + else datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + ) + payload = { + "group": g, + "legs": legs, + "result_label": result_label, + "pnl_amount": pnl_amount, + } + conn.execute( + """INSERT INTO strategy_trade_snapshots ( + strategy_type, source_id, symbol, exchange_symbol, direction, + result_label, status_at_close, opened_at, closed_at, pnl_amount, snapshot_json, created_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + STRATEGY_ROLL, + gid, + g.get("symbol"), + g.get("exchange_symbol"), + g.get("direction"), + result_label, + g.get("status"), + g.get("created_at"), + closed_at, + pnl_amount, + _json_dumps(payload), + closed_at, + ), + ) + prune_strategy_snapshots(conn, keep=STRATEGY_SNAPSHOTS_MAX_ROWS) + + +def prune_strategy_snapshots(conn, *, keep: int = STRATEGY_SNAPSHOTS_MAX_ROWS) -> None: + """仅保留最近 keep 条策略快照(按 closed_at / id 倒序)。""" + dedupe_strategy_snapshots(conn) + k = max(1, min(int(keep), 500)) + conn.execute( + """DELETE FROM strategy_trade_snapshots + WHERE id NOT IN ( + SELECT id FROM strategy_trade_snapshots + ORDER BY COALESCE(closed_at, created_at, '') DESC, id DESC + LIMIT ? + )""", + (k,), + ) + + +def _snapshot_pnl(row: dict, snap: dict) -> float | None: + for key in ("pnl_amount",): + v = row.get(key) + if v is not None and v != "": + try: + return float(v) + except (TypeError, ValueError): + pass + v = snap.get("pnl_amount") + if v is not None and v != "": + try: + return float(v) + except (TypeError, ValueError): + pass + return None + + +def _trend_dca_stats(snap: dict) -> dict: + levels = snap.get("dca_levels") or build_trend_dca_levels(snap) + dca_only = [ + lv + for lv in levels + if (lv.get("leg_key") or "") != "first" and (lv.get("label") or "") != "首仓" + ] + done = sum(1 for lv in dca_only if lv.get("status") == "done") + total = len(dca_only) + pending = total - done + if total <= 0: + tag = "na" + elif done <= 0: + tag = "no_dca" + elif done >= total: + tag = "dca_done" + else: + tag = "dca_partial" + return { + "dca_done": done, + "dca_total": total, + "dca_pending": pending, + "dca_tag": tag, + } + + +def _roll_leg_stats(snap: dict) -> dict: + legs = snap.get("legs") or [] + if not isinstance(legs, list): + legs = [] + filled = sum(1 for lg in legs if (lg.get("status") or "").lower() == "filled") + total = len(legs) + pending = total - filled + if total <= 0: + tag = "na" + elif filled <= 0: + tag = "no_dca" + elif filled >= total: + tag = "dca_done" + else: + tag = "dca_partial" + return { + "dca_done": filled, + "dca_total": total, + "dca_pending": pending, + "dca_tag": tag, + } + + +def enrich_strategy_snapshot_row(row: dict) -> dict: + d = dict(row or {}) + snap = d.get("snapshot") or {} + st = (d.get("strategy_type") or "").strip() + pnl = _snapshot_pnl(d, snap) + if pnl is not None: + if pnl > 1e-9: + d["filter_pnl"] = "profit" + elif pnl < -1e-9: + d["filter_pnl"] = "loss" + else: + d["filter_pnl"] = "flat" + else: + d["filter_pnl"] = "unknown" + snap_sym = "" + if isinstance(snap, dict): + snap_sym = (snap.get("symbol") or snap.get("exchange_symbol") or "").strip() + sym = (d.get("symbol") or d.get("exchange_symbol") or snap_sym or "").strip() + if sym: + d["symbol"] = d.get("symbol") or sym + d["exchange_symbol"] = d.get("exchange_symbol") or sym + d["filter_symbol"] = sym.upper().split("/")[0].split(":")[0] if sym else "" + closed = (d.get("closed_at") or d.get("created_at") or "").strip() + d["sort_ts"] = closed + if st == STRATEGY_TREND: + stats = _trend_dca_stats(snap) + d.update(stats) + legs_txt = ( + f"{stats['dca_done']}/{stats['dca_total']}" + if stats["dca_total"] > 0 + else "0/0" + ) + d["summary_dca"] = legs_txt + else: + stats = _roll_leg_stats(snap) + d.update(stats) + d["summary_dca"] = ( + f"{stats['dca_done']}/{stats['dca_total']}腿" + if stats["dca_total"] > 0 + else "—" + ) + return d + + +def list_strategy_snapshots(conn, *, limit: int = 200) -> list[dict]: + init_strategy_snapshot_table(conn) + rows = conn.execute( + "SELECT * FROM strategy_trade_snapshots ORDER BY id DESC LIMIT ?", + (max(1, min(int(limit), 500)),), + ).fetchall() + out = [] + seen: dict[tuple[str, int, str], int] = {} + for r in rows: + d = _row_dict(r) + try: + d["snapshot"] = json.loads(d.get("snapshot_json") or "{}") + except Exception: + d["snapshot"] = {} + st = (d.get("strategy_type") or "").strip() + d["strategy_label"] = "趋势回调" if st == STRATEGY_TREND else "顺势加仓" + enriched = enrich_strategy_snapshot_row(d) + try: + source_id = int(enriched.get("source_id") or 0) + except (TypeError, ValueError): + source_id = 0 + result_label = (enriched.get("result_label") or "").strip() + close_rank = _final_trend_close_rank(result_label) + if st == STRATEGY_TREND and source_id > 0 and close_rank > 0: + plan_key = (st, source_id) + snap_id = int(enriched.get("id") or 0) + prev = seen.get(plan_key) + if prev is not None: + prev_id, prev_rank = prev + if prev_rank > close_rank or (prev_rank == close_rank and prev_id >= snap_id): + continue + out = [x for x in out if int(x.get("id") or 0) != prev_id] + seen[plan_key] = (snap_id, close_rank) + out.append(enriched) + continue + key = (st, source_id, result_label) + snap_id = int(enriched.get("id") or 0) + prev = seen.get(key) + if prev is not None and prev[0] >= snap_id: + continue + if prev is not None: + out = [x for x in out if int(x.get("id") or 0) != prev[0]] + seen[key] = (snap_id, 0) + out.append(enriched) + return out + + +def list_strategy_snapshots_split( + conn, *, limit: int = STRATEGY_SNAPSHOTS_MAX_ROWS +) -> tuple[list[dict], list[dict], list[str]]: + """趋势 / 顺势分组,及筛选用币种列表。""" + all_rows = list_strategy_snapshots(conn, limit=limit) + trend = [r for r in all_rows if (r.get("strategy_type") or "") == STRATEGY_TREND] + roll = [r for r in all_rows if (r.get("strategy_type") or "") == STRATEGY_ROLL] + symbols = sorted({r.get("filter_symbol") or "" for r in all_rows if r.get("filter_symbol")}) + return trend, roll, symbols diff --git a/strategy_trade_labels.py b/lib/strategy/strategy_trade_labels.py similarity index 100% rename from strategy_trade_labels.py rename to lib/strategy/strategy_trade_labels.py diff --git a/strategy_trend_exchange.py b/lib/strategy/strategy_trend_exchange.py similarity index 100% rename from strategy_trend_exchange.py rename to lib/strategy/strategy_trend_exchange.py diff --git a/strategy_trend_lib.py b/lib/strategy/strategy_trend_lib.py similarity index 96% rename from strategy_trend_lib.py rename to lib/strategy/strategy_trend_lib.py index 9c0f712..0bbb30d 100644 --- a/strategy_trend_lib.py +++ b/lib/strategy/strategy_trend_lib.py @@ -1,695 +1,695 @@ -"""趋势回调策略:纯计算与校验(无 ccxt / Flask)。各所 adapter 负责张数精度与下单。""" -from __future__ import annotations - -import json -from typing import Any, Callable, Optional, Tuple - -AmountPreciseFn = Callable[[str, float], Optional[float]] - - -def calc_risk_fraction(direction: str, entry_price: float, stop_loss: float) -> Optional[float]: - try: - entry = float(entry_price) - sl = float(stop_loss) - if entry <= 0 or sl <= 0: - return None - if (direction or "long").strip().lower() == "short": - risk = sl - entry - else: - risk = entry - sl - if risk <= 0: - return None - return risk / entry - except (TypeError, ValueError): - return None - - -def trend_effective_margin_capital(plan: dict) -> float: - """按已开仓张数占计划总张数比例折算保证金(首仓/部分补仓时的盈亏估算)。""" - try: - plan_margin = float(plan.get("plan_margin_capital") or 0) - target = float(plan.get("target_order_amount") or 0) - open_amt = float(plan.get("order_amount_open") or 0) - except (TypeError, ValueError): - return float((plan or {}).get("plan_margin_capital") or 0) - if plan_margin <= 0: - return 0.0 - if target > 0 and open_amt > 0: - return round(plan_margin * min(1.0, open_amt / target), 8) - try: - first = float(plan.get("first_order_amount") or 0) - except (TypeError, ValueError): - first = 0.0 - if target > 0 and first > 0: - return round(plan_margin * min(1.0, first / target), 8) - return plan_margin - - -def trend_dca_level_reached(direction: str, mark_price: float, level: float) -> bool: - """做空:价升触达/越过档位即应补仓;做多:价跌触达/越过档位。""" - d = (direction or "long").strip().lower() - try: - pf = float(mark_price) - lv = float(level) - except (TypeError, ValueError): - return False - if d == "long": - return pf <= lv - return pf >= lv - - -def validate_trend_bounds(direction: str, stop_loss: float, add_upper: float) -> Optional[str]: - direction = (direction or "long").strip().lower() - if direction == "long": - if not (float(stop_loss) < float(add_upper)): - return "做多:止损价须低于补仓上沿" - else: - if not (float(stop_loss) > float(add_upper)): - return "做空:止损价须高于补仓下沿" - return None - - -def build_grid_prices(direction: str, sl: float, upper: float, n_legs: int) -> list[float]: - """在 (止损, 补仓区间远侧边界) 内生成 n_legs 个触发价(不含端点)。""" - sl, upper = float(sl), float(upper) - out: list[float] = [] - if n_legs <= 0: - return out - direction = (direction or "long").strip().lower() - if direction == "long": - if upper <= sl: - return out - span = upper - sl - for i in range(1, n_legs + 1): - t = i / float(n_legs + 1) - out.append(sl + t * span) - out.sort(reverse=True) - else: - if sl <= upper: - return out - span = sl - upper - for i in range(1, n_legs + 1): - t = i / float(n_legs + 1) - out.append(upper + t * span) - out.sort() - return [round(p, 10) for p in out] - - -def pick_dca_legs_and_per_leg( - exchange_symbol: str, - remainder_total: float, - want_legs: int, - amount_precise: AmountPreciseFn, - min_amount: float = 0.0, -) -> Tuple[int, float]: - """按最小张数约束自动减少档位数。返回 (有效档数, 每档参考张数)。""" - legs = max(1, int(want_legs)) - rem = float(remainder_total) - min_amt = float(min_amount or 0.0) - while legs >= 1: - per = rem / legs - per_p = amount_precise(exchange_symbol, per) - if per_p is None or per_p <= 0: - legs -= 1 - continue - if min_amt and per_p + 1e-12 < min_amt: - legs -= 1 - continue - return legs, per_p - one = amount_precise(exchange_symbol, rem) - if one is None or one <= 0: - return 0, 0.0 - return 1, one - - -def build_leg_amounts_json( - exchange_symbol: str, - remainder_total: float, - want_legs: int, - amount_precise: AmountPreciseFn, - min_amount: float = 0.0, -) -> Tuple[int, str, float]: - """拆分补仓张数 JSON。返回 (档位数, json列表, 每档参考)。""" - rem = amount_precise(exchange_symbol, float(remainder_total)) - if rem is None or rem <= 0: - return 0, "[]", 0.0 - n, _ = pick_dca_legs_and_per_leg(exchange_symbol, rem, want_legs, amount_precise, min_amount) - if n <= 0: - return 0, "[]", 0.0 - if n <= 1: - one = amount_precise(exchange_symbol, rem) - if one is None or one <= 0: - return 0, "[]", 0.0 - return 1, json.dumps([one]), one - unit = amount_precise(exchange_symbol, rem / n) - if unit is None or unit <= 0: - one = amount_precise(exchange_symbol, rem) - if one is None or one <= 0: - return 0, "[]", 0.0 - return 1, json.dumps([one]), one - parts: list[float] = [] - acc = 0.0 - for _ in range(n - 1): - parts.append(unit) - acc += unit - last = amount_precise(exchange_symbol, max(0.0, rem - acc)) - if last is None or last <= 0: - one = amount_precise(exchange_symbol, rem) - if one is None or one <= 0: - return 0, "[]", 0.0 - return 1, json.dumps([one]), one - parts.append(last) - return n, json.dumps(parts), unit - - -def compute_trend_plan_core( - *, - direction: str, - stop_loss: float, - add_upper: float, - risk_percent: float, - snapshot_usdt: float, - leverage: int, - live_price: float, - target_order_amount: float, - exchange_symbol: str, - dca_legs: int, - amount_precise: AmountPreciseFn, - min_amount: float = 0.0, - full_margin_buffer_ratio: float = 0.95, -) -> Tuple[Optional[dict[str, Any]], Optional[str]]: - """在已有 target_order_amount 时组装预览 payload(张数由调用方 prepare_order_amount 计算)。""" - rf = calc_risk_fraction(direction, add_upper, stop_loss) - if rf is None or rf <= 0: - return None, "止损与补仓区间边界组合无法计算风险比例" - risk_budget = float(snapshot_usdt) * (float(risk_percent) / 100.0) - notional = risk_budget / rf - margin_plan = notional / float(leverage) - margin_plan = min(margin_plan, float(snapshot_usdt) * float(full_margin_buffer_ratio)) - if margin_plan <= 0: - return None, "计划保证金过小" - first_amt = amount_precise(exchange_symbol, float(target_order_amount) * 0.5) - if first_amt is None or first_amt <= 0: - return None, "首仓张数过小(低于交易所最小张数),请提高风险比例或杠杆" - remainder_total = amount_precise(exchange_symbol, max(0.0, float(target_order_amount) - float(first_amt))) - if remainder_total is None: - remainder_total = 0.0 - n_legs, leg_json, per_ref = build_leg_amounts_json( - exchange_symbol, remainder_total, dca_legs, amount_precise, min_amount - ) - if n_legs <= 0: - return None, "剩余计划张数不足以拆出补仓档,请提高风险比例或放宽止损与补仓区间间距" - grid = build_grid_prices(direction, stop_loss, add_upper, n_legs) - if len(grid) != n_legs: - return None, "补仓网格生成失败" - try: - leg_list = json.loads(leg_json) - except Exception: - leg_list = [] - payload = { - "direction": direction, - "stop_loss": float(stop_loss), - "add_upper": float(add_upper), - "risk_percent": float(risk_percent), - "snapshot_available_usdt": float(snapshot_usdt), - "live_price_ref": float(live_price), - "plan_margin_capital": float(margin_plan), - "target_order_amount": float(target_order_amount), - "first_order_amount": float(first_amt), - "remainder_total": float(remainder_total), - "dca_legs": int(n_legs), - "per_leg_amount": float(per_ref), - "grid_prices_json": json.dumps(grid), - "leg_amounts_json": leg_json, - "grid": grid, - "leg_amounts": leg_list, - } - return payload, None - - -def calc_planned_reward_risk_ratio( - direction: str, entry_price: float, stop_loss: float, take_profit: float -) -> Optional[float]: - """盈亏比(reward/risk),与四所 calc_rr_ratio 口径一致。""" - try: - entry = float(entry_price) - sl = float(stop_loss) - tp = float(take_profit) - if entry <= 0 or sl <= 0 or tp <= 0: - return None - direction = (direction or "long").strip().lower() - if direction == "short": - risk = sl - entry - reward = entry - tp - else: - risk = entry - sl - reward = tp - entry - if risk <= 0 or reward <= 0: - return None - return round(reward / risk, 4) - except (TypeError, ValueError): - return None - - -def calc_take_profit_for_rr( - direction: str, entry_price: float, stop_loss: float, reward_risk_ratio: float -) -> Optional[float]: - """按统一止损与目标 RR 反推止盈价。""" - try: - entry = float(entry_price) - sl = float(stop_loss) - rr = float(reward_risk_ratio) - if entry <= 0 or sl <= 0 or rr <= 0: - return None - direction = (direction or "long").strip().lower() - if direction == "short": - risk = sl - entry - if risk <= 0: - return None - return round(entry - rr * risk, 10) - risk = entry - sl - if risk <= 0: - return None - return round(entry + rr * risk, 10) - except (TypeError, ValueError): - return None - - -def calc_risk_budget_usdt(snapshot_usdt: float, risk_percent: float) -> Optional[float]: - """计划止损金额 U = 可用快照 × 风险比例。""" - try: - snap = float(snapshot_usdt) - rp = float(risk_percent) - if snap <= 0 or rp <= 0: - return None - return round(snap * rp / 100.0, 4) - except (TypeError, ValueError): - return None - - -def calc_money_reward_risk_ratio(profit_u: float, risk_u: float) -> Optional[float]: - """金额盈亏比 = 止盈盈利 U / 止损金额 U。""" - try: - r = float(risk_u) - p = float(profit_u) - if r <= 0: - return None - return round(p / r, 4) - except (TypeError, ValueError): - return None - - -def calc_tp_profit_usdt( - direction: str, - avg_entry: float, - take_profit_price: float, - contracts: float, - contract_size: float = 1.0, -) -> Optional[float]: - """到达止盈价时,按累计张数与加仓后均价的盈利 U。""" - try: - from hub_position_metrics import estimate_linear_swap_upnl_usdt - - return estimate_linear_swap_upnl_usdt( - direction, float(avg_entry), float(take_profit_price), float(contracts), float(contract_size) - ) - except (TypeError, ValueError): - return None - - -def weighted_avg_entry(legs: list[tuple[float, float]]) -> Optional[float]: - """按 (成交价, 张数) 加权均价。""" - total = 0.0 - cost = 0.0 - for price, amount in legs or []: - try: - p = float(price) - a = float(amount) - except (TypeError, ValueError): - continue - if a <= 0: - continue - total += a - cost += p * a - if total <= 0: - return None - return cost / total - - -def parse_leg_fill_prices(plan: dict) -> list[float]: - """首仓 + 各档补仓实际成交价列表。""" - try: - raw = json.loads((plan or {}).get("leg_fill_prices_json") or "[]") - if not isinstance(raw, list): - return [] - out: list[float] = [] - for item in raw: - try: - out.append(float(item)) - except (TypeError, ValueError): - continue - return out - except Exception: - return [] - - -def append_leg_fill_price_json(existing_json: str | None, fill_px: float) -> str: - fills = parse_leg_fill_prices({"leg_fill_prices_json": existing_json}) - fills.append(float(fill_px)) - return json.dumps(fills, ensure_ascii=False, separators=(",", ":")) - - -def trend_leg_grid_price(plan: dict, leg_idx: int) -> Optional[float]: - """补仓 leg_idx(1..N) 的计划网格触发价;首仓返回 None。""" - if leg_idx <= 0: - return None - try: - grid = [float(x) for x in json.loads((plan or {}).get("grid_prices_json") or "[]")] - except Exception: - grid = [] - gi = leg_idx - 1 - if 0 <= gi < len(grid): - return float(grid[gi]) - return None - - -def trend_leg_display_price(plan: dict, leg_idx: int) -> Optional[float]: - """ - 四所统一:单档展示价 = leg_fill_prices_json 实际记录,否则计划网格(首仓用均价/参考价)。 - 禁止为凑均价反推虚构成交价。 - """ - p = plan or {} - fills = parse_leg_fill_prices(p) - if len(fills) > leg_idx: - return float(fills[leg_idx]) - if leg_idx == 0: - try: - return float(p.get("avg_entry_price")) - except (TypeError, ValueError): - pass - try: - ref = p.get("live_price_ref") - if ref not in (None, ""): - return float(ref) - except (TypeError, ValueError): - pass - return None - return trend_leg_grid_price(p, leg_idx) - - -def reconcile_trend_leg_fill_prices(plan: dict) -> list[float]: - """首仓(0)+已补仓(1..legs_done) 展示价列表(四所共用 trend_leg_display_price)。""" - p = plan or {} - if int(p.get("first_order_done") or 0) == 0: - return [] - try: - legs_done = int(p.get("legs_done") or 0) - except (TypeError, ValueError): - legs_done = 0 - result: list[float] = [] - for leg_idx in range(legs_done + 1): - px = trend_leg_display_price(p, leg_idx) - result.append(float(px) if px is not None else 0.0) - return result - - -def calc_trend_plan_money_metrics(plan: dict) -> dict: - """运行中计划头部:按快照风险金额计算盈亏比(止盈盈利 U / 风险 U)。""" - out = {"money_rr": None, "risk_amount_u": None} - p = plan or {} - try: - direction = (p.get("direction") or "long").strip().lower() - user_tp = float(p.get("take_profit")) - avg = float(p.get("avg_entry_price")) - open_amt = float(p.get("order_amount_open") or p.get("first_order_amount") or 0) - snapshot = float(p.get("snapshot_available_usdt")) - risk_percent = float(p.get("risk_percent")) - except (TypeError, ValueError): - return out - if avg <= 0 or open_amt <= 0: - return out - risk_u = calc_risk_budget_usdt(snapshot, risk_percent) - if risk_u is None or risk_u <= 0: - return out - out["risk_amount_u"] = risk_u - try: - contract_size = float(p.get("contract_size") or 1.0) - if contract_size <= 0: - contract_size = 1.0 - except (TypeError, ValueError): - contract_size = 1.0 - profit_u = calc_tp_profit_usdt(direction, avg, user_tp, open_amt, contract_size) - out["money_rr"] = calc_money_reward_risk_ratio(profit_u, risk_u) - return out - - -def build_trend_preview_level_rows(preview: dict) -> tuple[dict, list[dict]]: - """ - 预览:表单止盈价下每档累计持仓的盈利 U;止损金额 = 快照×风险;盈亏比按金额对比。 - 返回 (增强后的 preview 字段, 表格行列表,含首仓行)。 - """ - p = dict(preview or {}) - direction = (p.get("direction") or "long").strip().lower() - try: - ref = float(p.get("live_price_ref")) - sl = float(p.get("stop_loss")) - user_tp = float(p.get("take_profit")) - first_amt = float(p.get("first_order_amount")) - snapshot = float(p.get("snapshot_available_usdt")) - risk_percent = float(p.get("risk_percent")) - except (TypeError, ValueError): - return p, [] - - risk_u = calc_risk_budget_usdt(snapshot, risk_percent) - if risk_u is None or risk_u <= 0: - return p, [] - - try: - contract_size = float(p.get("contract_size") or 1.0) - if contract_size <= 0: - contract_size = 1.0 - except (TypeError, ValueError): - contract_size = 1.0 - - p["preview_risk_amount_u"] = risk_u - p["preview_take_profit_price"] = user_tp - p["preview_unified_stop_loss"] = sl - - try: - grid = json.loads(p.get("grid_prices_json") or "[]") - if not isinstance(grid, list): - grid = [] - except Exception: - grid = [] - try: - leg_amounts = json.loads(p.get("leg_amounts_json") or "[]") - if not isinstance(leg_amounts, list): - leg_amounts = [] - except Exception: - leg_amounts = [] - - def _row_dict( - *, - i: int, - label: str, - price: float, - leg_contracts: float, - cum_contracts: float, - avg: float, - is_first: bool, - ) -> dict: - profit_u = calc_tp_profit_usdt(direction, avg, user_tp, cum_contracts, contract_size) - rr_money = calc_money_reward_risk_ratio(profit_u, risk_u) if profit_u is not None else None - return { - "i": i, - "label": label, - "price": price, - "contracts": leg_contracts, - "cum_contracts": cum_contracts, - "avg_entry": avg, - "take_profit_price": user_tp, - "profit_u": profit_u, - "risk_u": risk_u, - "rr": rr_money, - "stop_loss_price": sl, - "take_profit": profit_u, - "stop_loss": risk_u, - "is_first": is_first, - } - - cum_contracts = first_amt - first_profit = calc_tp_profit_usdt(direction, ref, user_tp, cum_contracts, contract_size) - first_rr = calc_money_reward_risk_ratio(first_profit, risk_u) if first_profit is not None else None - p["preview_first_profit_u"] = first_profit - p["preview_target_rr"] = first_rr - p["preview_first_take_profit"] = user_tp - - rows: list[dict] = [ - _row_dict( - i=0, - label="首仓", - price=ref, - leg_contracts=first_amt, - cum_contracts=cum_contracts, - avg=ref, - is_first=True, - ) - ] - accumulated: list[tuple[float, float]] = [(ref, first_amt)] - for i, pair in enumerate(zip(grid, leg_amounts), 1): - try: - price = float(pair[0]) - leg_contracts = float(pair[1]) - except (TypeError, ValueError): - continue - accumulated.append((price, leg_contracts)) - avg = weighted_avg_entry(accumulated) - if avg is None: - continue - cum_contracts += leg_contracts - rows.append( - _row_dict( - i=i, - label=f"补仓{i}", - price=price, - leg_contracts=leg_contracts, - cum_contracts=cum_contracts, - avg=avg, - is_first=False, - ) - ) - return p, rows - - -def enrich_trend_dca_levels_with_tp(plan: dict, levels: list[dict]) -> list[dict]: - """ - 四所统一补仓表 enrich(实例策略页 + 中控 monitor 共用)。 - 触发价:实际成交价或计划网格;末档加仓后均价用持仓均价;禁止反推虚构成交价。 - """ - if not levels: - return levels - p = plan or {} - direction = (p.get("direction") or "long").strip().lower() - try: - sl = float(p.get("stop_loss")) - user_tp = float(p.get("take_profit")) - first_amt = float(p.get("first_order_amount")) - snapshot = float(p.get("snapshot_available_usdt")) - risk_percent = float(p.get("risk_percent")) - except (TypeError, ValueError): - return levels - - risk_u = calc_risk_budget_usdt(snapshot, risk_percent) - if risk_u is None or risk_u <= 0: - return levels - - try: - legs_done = int(p.get("legs_done") or 0) - except (TypeError, ValueError): - legs_done = 0 - first_done = int(p.get("first_order_done") or 0) != 0 - try: - target_avg = float(p.get("avg_entry_price")) - except (TypeError, ValueError): - target_avg = None - - ref_raw = p.get("live_price_ref") - if ref_raw in (None, ""): - ref_raw = p.get("avg_entry_price") - try: - ref = float(ref_raw) - except (TypeError, ValueError): - return levels - - try: - contract_size = float(p.get("contract_size") or 1.0) - if contract_size <= 0: - contract_size = 1.0 - except (TypeError, ValueError): - contract_size = 1.0 - - out: list[dict] = [] - accumulated: list[tuple[float, float]] = [] - cum_contracts = 0.0 - for lv in levels: - row = dict(lv) - is_first = row.get("leg_key") == "first" or row.get("label") == "首仓" or row.get("i") == 0 - row_cum = cum_contracts - if is_first: - try: - amt_f = float(row.get("contracts") if row.get("contracts") is not None else first_amt) - except (TypeError, ValueError): - amt_f = first_amt - if first_done: - fill_px = trend_leg_display_price(p, 0) - if fill_px is None: - try: - fill_px = float(p.get("avg_entry_price") or ref) - except (TypeError, ValueError): - fill_px = ref - accumulated = [(float(fill_px), amt_f)] - cum_contracts = amt_f - row_cum = cum_contracts - row["price"] = fill_px - if target_avg is not None and legs_done == 0: - row["avg_entry"] = target_avg - else: - row["avg_entry"] = float(fill_px) - else: - accumulated = [(ref, amt_f)] - cum_contracts = amt_f - row_cum = cum_contracts - row["avg_entry"] = ref - else: - try: - leg_num = int(row.get("i") or 0) - except (TypeError, ValueError): - leg_num = 0 - grid_trigger = row.get("price") - try: - grid_trigger_f = float(grid_trigger) if grid_trigger is not None else None - except (TypeError, ValueError): - grid_trigger_f = None - try: - leg_contracts = float(row.get("contracts") or 0) - except (TypeError, ValueError): - leg_contracts = 0.0 - done = row.get("status") == "done" or (leg_num > 0 and leg_num <= legs_done) - if done and leg_contracts > 0: - fill_px = trend_leg_display_price(p, leg_num) - if fill_px is None: - fill_px = grid_trigger_f if grid_trigger_f is not None else ref - row["price"] = fill_px - accumulated.append((fill_px, leg_contracts)) - cum_contracts += leg_contracts - row_cum = cum_contracts - if leg_num == legs_done and target_avg is not None: - row["avg_entry"] = target_avg - else: - avg = weighted_avg_entry(accumulated) - if avg is not None: - row["avg_entry"] = avg - elif grid_trigger_f is not None and leg_contracts > 0: - row["price"] = grid_trigger_f - projected = accumulated + [(grid_trigger_f, leg_contracts)] - avg = weighted_avg_entry(projected) - if avg is not None: - row["avg_entry"] = avg - row_cum = cum_contracts + leg_contracts - elif grid_trigger_f is not None: - row["price"] = grid_trigger_f - - avg_entry = row.get("avg_entry") - if avg_entry is not None and row_cum > 0: - profit_u = calc_tp_profit_usdt( - direction, float(avg_entry), user_tp, row_cum, contract_size - ) - row["take_profit_price"] = user_tp - row["profit_u"] = profit_u - row["risk_u"] = risk_u - row["rr"] = calc_money_reward_risk_ratio(profit_u, risk_u) if profit_u is not None else None - row["take_profit"] = profit_u - row["stop_loss"] = risk_u - row["stop_loss_price"] = sl - out.append(row) - return out +"""趋势回调策略:纯计算与校验(无 ccxt / Flask)。各所 adapter 负责张数精度与下单。""" +from __future__ import annotations + +import json +from typing import Any, Callable, Optional, Tuple + +AmountPreciseFn = Callable[[str, float], Optional[float]] + + +def calc_risk_fraction(direction: str, entry_price: float, stop_loss: float) -> Optional[float]: + try: + entry = float(entry_price) + sl = float(stop_loss) + if entry <= 0 or sl <= 0: + return None + if (direction or "long").strip().lower() == "short": + risk = sl - entry + else: + risk = entry - sl + if risk <= 0: + return None + return risk / entry + except (TypeError, ValueError): + return None + + +def trend_effective_margin_capital(plan: dict) -> float: + """按已开仓张数占计划总张数比例折算保证金(首仓/部分补仓时的盈亏估算)。""" + try: + plan_margin = float(plan.get("plan_margin_capital") or 0) + target = float(plan.get("target_order_amount") or 0) + open_amt = float(plan.get("order_amount_open") or 0) + except (TypeError, ValueError): + return float((plan or {}).get("plan_margin_capital") or 0) + if plan_margin <= 0: + return 0.0 + if target > 0 and open_amt > 0: + return round(plan_margin * min(1.0, open_amt / target), 8) + try: + first = float(plan.get("first_order_amount") or 0) + except (TypeError, ValueError): + first = 0.0 + if target > 0 and first > 0: + return round(plan_margin * min(1.0, first / target), 8) + return plan_margin + + +def trend_dca_level_reached(direction: str, mark_price: float, level: float) -> bool: + """做空:价升触达/越过档位即应补仓;做多:价跌触达/越过档位。""" + d = (direction or "long").strip().lower() + try: + pf = float(mark_price) + lv = float(level) + except (TypeError, ValueError): + return False + if d == "long": + return pf <= lv + return pf >= lv + + +def validate_trend_bounds(direction: str, stop_loss: float, add_upper: float) -> Optional[str]: + direction = (direction or "long").strip().lower() + if direction == "long": + if not (float(stop_loss) < float(add_upper)): + return "做多:止损价须低于补仓上沿" + else: + if not (float(stop_loss) > float(add_upper)): + return "做空:止损价须高于补仓下沿" + return None + + +def build_grid_prices(direction: str, sl: float, upper: float, n_legs: int) -> list[float]: + """在 (止损, 补仓区间远侧边界) 内生成 n_legs 个触发价(不含端点)。""" + sl, upper = float(sl), float(upper) + out: list[float] = [] + if n_legs <= 0: + return out + direction = (direction or "long").strip().lower() + if direction == "long": + if upper <= sl: + return out + span = upper - sl + for i in range(1, n_legs + 1): + t = i / float(n_legs + 1) + out.append(sl + t * span) + out.sort(reverse=True) + else: + if sl <= upper: + return out + span = sl - upper + for i in range(1, n_legs + 1): + t = i / float(n_legs + 1) + out.append(upper + t * span) + out.sort() + return [round(p, 10) for p in out] + + +def pick_dca_legs_and_per_leg( + exchange_symbol: str, + remainder_total: float, + want_legs: int, + amount_precise: AmountPreciseFn, + min_amount: float = 0.0, +) -> Tuple[int, float]: + """按最小张数约束自动减少档位数。返回 (有效档数, 每档参考张数)。""" + legs = max(1, int(want_legs)) + rem = float(remainder_total) + min_amt = float(min_amount or 0.0) + while legs >= 1: + per = rem / legs + per_p = amount_precise(exchange_symbol, per) + if per_p is None or per_p <= 0: + legs -= 1 + continue + if min_amt and per_p + 1e-12 < min_amt: + legs -= 1 + continue + return legs, per_p + one = amount_precise(exchange_symbol, rem) + if one is None or one <= 0: + return 0, 0.0 + return 1, one + + +def build_leg_amounts_json( + exchange_symbol: str, + remainder_total: float, + want_legs: int, + amount_precise: AmountPreciseFn, + min_amount: float = 0.0, +) -> Tuple[int, str, float]: + """拆分补仓张数 JSON。返回 (档位数, json列表, 每档参考)。""" + rem = amount_precise(exchange_symbol, float(remainder_total)) + if rem is None or rem <= 0: + return 0, "[]", 0.0 + n, _ = pick_dca_legs_and_per_leg(exchange_symbol, rem, want_legs, amount_precise, min_amount) + if n <= 0: + return 0, "[]", 0.0 + if n <= 1: + one = amount_precise(exchange_symbol, rem) + if one is None or one <= 0: + return 0, "[]", 0.0 + return 1, json.dumps([one]), one + unit = amount_precise(exchange_symbol, rem / n) + if unit is None or unit <= 0: + one = amount_precise(exchange_symbol, rem) + if one is None or one <= 0: + return 0, "[]", 0.0 + return 1, json.dumps([one]), one + parts: list[float] = [] + acc = 0.0 + for _ in range(n - 1): + parts.append(unit) + acc += unit + last = amount_precise(exchange_symbol, max(0.0, rem - acc)) + if last is None or last <= 0: + one = amount_precise(exchange_symbol, rem) + if one is None or one <= 0: + return 0, "[]", 0.0 + return 1, json.dumps([one]), one + parts.append(last) + return n, json.dumps(parts), unit + + +def compute_trend_plan_core( + *, + direction: str, + stop_loss: float, + add_upper: float, + risk_percent: float, + snapshot_usdt: float, + leverage: int, + live_price: float, + target_order_amount: float, + exchange_symbol: str, + dca_legs: int, + amount_precise: AmountPreciseFn, + min_amount: float = 0.0, + full_margin_buffer_ratio: float = 0.95, +) -> Tuple[Optional[dict[str, Any]], Optional[str]]: + """在已有 target_order_amount 时组装预览 payload(张数由调用方 prepare_order_amount 计算)。""" + rf = calc_risk_fraction(direction, add_upper, stop_loss) + if rf is None or rf <= 0: + return None, "止损与补仓区间边界组合无法计算风险比例" + risk_budget = float(snapshot_usdt) * (float(risk_percent) / 100.0) + notional = risk_budget / rf + margin_plan = notional / float(leverage) + margin_plan = min(margin_plan, float(snapshot_usdt) * float(full_margin_buffer_ratio)) + if margin_plan <= 0: + return None, "计划保证金过小" + first_amt = amount_precise(exchange_symbol, float(target_order_amount) * 0.5) + if first_amt is None or first_amt <= 0: + return None, "首仓张数过小(低于交易所最小张数),请提高风险比例或杠杆" + remainder_total = amount_precise(exchange_symbol, max(0.0, float(target_order_amount) - float(first_amt))) + if remainder_total is None: + remainder_total = 0.0 + n_legs, leg_json, per_ref = build_leg_amounts_json( + exchange_symbol, remainder_total, dca_legs, amount_precise, min_amount + ) + if n_legs <= 0: + return None, "剩余计划张数不足以拆出补仓档,请提高风险比例或放宽止损与补仓区间间距" + grid = build_grid_prices(direction, stop_loss, add_upper, n_legs) + if len(grid) != n_legs: + return None, "补仓网格生成失败" + try: + leg_list = json.loads(leg_json) + except Exception: + leg_list = [] + payload = { + "direction": direction, + "stop_loss": float(stop_loss), + "add_upper": float(add_upper), + "risk_percent": float(risk_percent), + "snapshot_available_usdt": float(snapshot_usdt), + "live_price_ref": float(live_price), + "plan_margin_capital": float(margin_plan), + "target_order_amount": float(target_order_amount), + "first_order_amount": float(first_amt), + "remainder_total": float(remainder_total), + "dca_legs": int(n_legs), + "per_leg_amount": float(per_ref), + "grid_prices_json": json.dumps(grid), + "leg_amounts_json": leg_json, + "grid": grid, + "leg_amounts": leg_list, + } + return payload, None + + +def calc_planned_reward_risk_ratio( + direction: str, entry_price: float, stop_loss: float, take_profit: float +) -> Optional[float]: + """盈亏比(reward/risk),与四所 calc_rr_ratio 口径一致。""" + try: + entry = float(entry_price) + sl = float(stop_loss) + tp = float(take_profit) + if entry <= 0 or sl <= 0 or tp <= 0: + return None + direction = (direction or "long").strip().lower() + if direction == "short": + risk = sl - entry + reward = entry - tp + else: + risk = entry - sl + reward = tp - entry + if risk <= 0 or reward <= 0: + return None + return round(reward / risk, 4) + except (TypeError, ValueError): + return None + + +def calc_take_profit_for_rr( + direction: str, entry_price: float, stop_loss: float, reward_risk_ratio: float +) -> Optional[float]: + """按统一止损与目标 RR 反推止盈价。""" + try: + entry = float(entry_price) + sl = float(stop_loss) + rr = float(reward_risk_ratio) + if entry <= 0 or sl <= 0 or rr <= 0: + return None + direction = (direction or "long").strip().lower() + if direction == "short": + risk = sl - entry + if risk <= 0: + return None + return round(entry - rr * risk, 10) + risk = entry - sl + if risk <= 0: + return None + return round(entry + rr * risk, 10) + except (TypeError, ValueError): + return None + + +def calc_risk_budget_usdt(snapshot_usdt: float, risk_percent: float) -> Optional[float]: + """计划止损金额 U = 可用快照 × 风险比例。""" + try: + snap = float(snapshot_usdt) + rp = float(risk_percent) + if snap <= 0 or rp <= 0: + return None + return round(snap * rp / 100.0, 4) + except (TypeError, ValueError): + return None + + +def calc_money_reward_risk_ratio(profit_u: float, risk_u: float) -> Optional[float]: + """金额盈亏比 = 止盈盈利 U / 止损金额 U。""" + try: + r = float(risk_u) + p = float(profit_u) + if r <= 0: + return None + return round(p / r, 4) + except (TypeError, ValueError): + return None + + +def calc_tp_profit_usdt( + direction: str, + avg_entry: float, + take_profit_price: float, + contracts: float, + contract_size: float = 1.0, +) -> Optional[float]: + """到达止盈价时,按累计张数与加仓后均价的盈利 U。""" + try: + from lib.hub.hub_position_metrics import estimate_linear_swap_upnl_usdt + + return estimate_linear_swap_upnl_usdt( + direction, float(avg_entry), float(take_profit_price), float(contracts), float(contract_size) + ) + except (TypeError, ValueError): + return None + + +def weighted_avg_entry(legs: list[tuple[float, float]]) -> Optional[float]: + """按 (成交价, 张数) 加权均价。""" + total = 0.0 + cost = 0.0 + for price, amount in legs or []: + try: + p = float(price) + a = float(amount) + except (TypeError, ValueError): + continue + if a <= 0: + continue + total += a + cost += p * a + if total <= 0: + return None + return cost / total + + +def parse_leg_fill_prices(plan: dict) -> list[float]: + """首仓 + 各档补仓实际成交价列表。""" + try: + raw = json.loads((plan or {}).get("leg_fill_prices_json") or "[]") + if not isinstance(raw, list): + return [] + out: list[float] = [] + for item in raw: + try: + out.append(float(item)) + except (TypeError, ValueError): + continue + return out + except Exception: + return [] + + +def append_leg_fill_price_json(existing_json: str | None, fill_px: float) -> str: + fills = parse_leg_fill_prices({"leg_fill_prices_json": existing_json}) + fills.append(float(fill_px)) + return json.dumps(fills, ensure_ascii=False, separators=(",", ":")) + + +def trend_leg_grid_price(plan: dict, leg_idx: int) -> Optional[float]: + """补仓 leg_idx(1..N) 的计划网格触发价;首仓返回 None。""" + if leg_idx <= 0: + return None + try: + grid = [float(x) for x in json.loads((plan or {}).get("grid_prices_json") or "[]")] + except Exception: + grid = [] + gi = leg_idx - 1 + if 0 <= gi < len(grid): + return float(grid[gi]) + return None + + +def trend_leg_display_price(plan: dict, leg_idx: int) -> Optional[float]: + """ + 四所统一:单档展示价 = leg_fill_prices_json 实际记录,否则计划网格(首仓用均价/参考价)。 + 禁止为凑均价反推虚构成交价。 + """ + p = plan or {} + fills = parse_leg_fill_prices(p) + if len(fills) > leg_idx: + return float(fills[leg_idx]) + if leg_idx == 0: + try: + return float(p.get("avg_entry_price")) + except (TypeError, ValueError): + pass + try: + ref = p.get("live_price_ref") + if ref not in (None, ""): + return float(ref) + except (TypeError, ValueError): + pass + return None + return trend_leg_grid_price(p, leg_idx) + + +def reconcile_trend_leg_fill_prices(plan: dict) -> list[float]: + """首仓(0)+已补仓(1..legs_done) 展示价列表(四所共用 trend_leg_display_price)。""" + p = plan or {} + if int(p.get("first_order_done") or 0) == 0: + return [] + try: + legs_done = int(p.get("legs_done") or 0) + except (TypeError, ValueError): + legs_done = 0 + result: list[float] = [] + for leg_idx in range(legs_done + 1): + px = trend_leg_display_price(p, leg_idx) + result.append(float(px) if px is not None else 0.0) + return result + + +def calc_trend_plan_money_metrics(plan: dict) -> dict: + """运行中计划头部:按快照风险金额计算盈亏比(止盈盈利 U / 风险 U)。""" + out = {"money_rr": None, "risk_amount_u": None} + p = plan or {} + try: + direction = (p.get("direction") or "long").strip().lower() + user_tp = float(p.get("take_profit")) + avg = float(p.get("avg_entry_price")) + open_amt = float(p.get("order_amount_open") or p.get("first_order_amount") or 0) + snapshot = float(p.get("snapshot_available_usdt")) + risk_percent = float(p.get("risk_percent")) + except (TypeError, ValueError): + return out + if avg <= 0 or open_amt <= 0: + return out + risk_u = calc_risk_budget_usdt(snapshot, risk_percent) + if risk_u is None or risk_u <= 0: + return out + out["risk_amount_u"] = risk_u + try: + contract_size = float(p.get("contract_size") or 1.0) + if contract_size <= 0: + contract_size = 1.0 + except (TypeError, ValueError): + contract_size = 1.0 + profit_u = calc_tp_profit_usdt(direction, avg, user_tp, open_amt, contract_size) + out["money_rr"] = calc_money_reward_risk_ratio(profit_u, risk_u) + return out + + +def build_trend_preview_level_rows(preview: dict) -> tuple[dict, list[dict]]: + """ + 预览:表单止盈价下每档累计持仓的盈利 U;止损金额 = 快照×风险;盈亏比按金额对比。 + 返回 (增强后的 preview 字段, 表格行列表,含首仓行)。 + """ + p = dict(preview or {}) + direction = (p.get("direction") or "long").strip().lower() + try: + ref = float(p.get("live_price_ref")) + sl = float(p.get("stop_loss")) + user_tp = float(p.get("take_profit")) + first_amt = float(p.get("first_order_amount")) + snapshot = float(p.get("snapshot_available_usdt")) + risk_percent = float(p.get("risk_percent")) + except (TypeError, ValueError): + return p, [] + + risk_u = calc_risk_budget_usdt(snapshot, risk_percent) + if risk_u is None or risk_u <= 0: + return p, [] + + try: + contract_size = float(p.get("contract_size") or 1.0) + if contract_size <= 0: + contract_size = 1.0 + except (TypeError, ValueError): + contract_size = 1.0 + + p["preview_risk_amount_u"] = risk_u + p["preview_take_profit_price"] = user_tp + p["preview_unified_stop_loss"] = sl + + try: + grid = json.loads(p.get("grid_prices_json") or "[]") + if not isinstance(grid, list): + grid = [] + except Exception: + grid = [] + try: + leg_amounts = json.loads(p.get("leg_amounts_json") or "[]") + if not isinstance(leg_amounts, list): + leg_amounts = [] + except Exception: + leg_amounts = [] + + def _row_dict( + *, + i: int, + label: str, + price: float, + leg_contracts: float, + cum_contracts: float, + avg: float, + is_first: bool, + ) -> dict: + profit_u = calc_tp_profit_usdt(direction, avg, user_tp, cum_contracts, contract_size) + rr_money = calc_money_reward_risk_ratio(profit_u, risk_u) if profit_u is not None else None + return { + "i": i, + "label": label, + "price": price, + "contracts": leg_contracts, + "cum_contracts": cum_contracts, + "avg_entry": avg, + "take_profit_price": user_tp, + "profit_u": profit_u, + "risk_u": risk_u, + "rr": rr_money, + "stop_loss_price": sl, + "take_profit": profit_u, + "stop_loss": risk_u, + "is_first": is_first, + } + + cum_contracts = first_amt + first_profit = calc_tp_profit_usdt(direction, ref, user_tp, cum_contracts, contract_size) + first_rr = calc_money_reward_risk_ratio(first_profit, risk_u) if first_profit is not None else None + p["preview_first_profit_u"] = first_profit + p["preview_target_rr"] = first_rr + p["preview_first_take_profit"] = user_tp + + rows: list[dict] = [ + _row_dict( + i=0, + label="首仓", + price=ref, + leg_contracts=first_amt, + cum_contracts=cum_contracts, + avg=ref, + is_first=True, + ) + ] + accumulated: list[tuple[float, float]] = [(ref, first_amt)] + for i, pair in enumerate(zip(grid, leg_amounts), 1): + try: + price = float(pair[0]) + leg_contracts = float(pair[1]) + except (TypeError, ValueError): + continue + accumulated.append((price, leg_contracts)) + avg = weighted_avg_entry(accumulated) + if avg is None: + continue + cum_contracts += leg_contracts + rows.append( + _row_dict( + i=i, + label=f"补仓{i}", + price=price, + leg_contracts=leg_contracts, + cum_contracts=cum_contracts, + avg=avg, + is_first=False, + ) + ) + return p, rows + + +def enrich_trend_dca_levels_with_tp(plan: dict, levels: list[dict]) -> list[dict]: + """ + 四所统一补仓表 enrich(实例策略页 + 中控 monitor 共用)。 + 触发价:实际成交价或计划网格;末档加仓后均价用持仓均价;禁止反推虚构成交价。 + """ + if not levels: + return levels + p = plan or {} + direction = (p.get("direction") or "long").strip().lower() + try: + sl = float(p.get("stop_loss")) + user_tp = float(p.get("take_profit")) + first_amt = float(p.get("first_order_amount")) + snapshot = float(p.get("snapshot_available_usdt")) + risk_percent = float(p.get("risk_percent")) + except (TypeError, ValueError): + return levels + + risk_u = calc_risk_budget_usdt(snapshot, risk_percent) + if risk_u is None or risk_u <= 0: + return levels + + try: + legs_done = int(p.get("legs_done") or 0) + except (TypeError, ValueError): + legs_done = 0 + first_done = int(p.get("first_order_done") or 0) != 0 + try: + target_avg = float(p.get("avg_entry_price")) + except (TypeError, ValueError): + target_avg = None + + ref_raw = p.get("live_price_ref") + if ref_raw in (None, ""): + ref_raw = p.get("avg_entry_price") + try: + ref = float(ref_raw) + except (TypeError, ValueError): + return levels + + try: + contract_size = float(p.get("contract_size") or 1.0) + if contract_size <= 0: + contract_size = 1.0 + except (TypeError, ValueError): + contract_size = 1.0 + + out: list[dict] = [] + accumulated: list[tuple[float, float]] = [] + cum_contracts = 0.0 + for lv in levels: + row = dict(lv) + is_first = row.get("leg_key") == "first" or row.get("label") == "首仓" or row.get("i") == 0 + row_cum = cum_contracts + if is_first: + try: + amt_f = float(row.get("contracts") if row.get("contracts") is not None else first_amt) + except (TypeError, ValueError): + amt_f = first_amt + if first_done: + fill_px = trend_leg_display_price(p, 0) + if fill_px is None: + try: + fill_px = float(p.get("avg_entry_price") or ref) + except (TypeError, ValueError): + fill_px = ref + accumulated = [(float(fill_px), amt_f)] + cum_contracts = amt_f + row_cum = cum_contracts + row["price"] = fill_px + if target_avg is not None and legs_done == 0: + row["avg_entry"] = target_avg + else: + row["avg_entry"] = float(fill_px) + else: + accumulated = [(ref, amt_f)] + cum_contracts = amt_f + row_cum = cum_contracts + row["avg_entry"] = ref + else: + try: + leg_num = int(row.get("i") or 0) + except (TypeError, ValueError): + leg_num = 0 + grid_trigger = row.get("price") + try: + grid_trigger_f = float(grid_trigger) if grid_trigger is not None else None + except (TypeError, ValueError): + grid_trigger_f = None + try: + leg_contracts = float(row.get("contracts") or 0) + except (TypeError, ValueError): + leg_contracts = 0.0 + done = row.get("status") == "done" or (leg_num > 0 and leg_num <= legs_done) + if done and leg_contracts > 0: + fill_px = trend_leg_display_price(p, leg_num) + if fill_px is None: + fill_px = grid_trigger_f if grid_trigger_f is not None else ref + row["price"] = fill_px + accumulated.append((fill_px, leg_contracts)) + cum_contracts += leg_contracts + row_cum = cum_contracts + if leg_num == legs_done and target_avg is not None: + row["avg_entry"] = target_avg + else: + avg = weighted_avg_entry(accumulated) + if avg is not None: + row["avg_entry"] = avg + elif grid_trigger_f is not None and leg_contracts > 0: + row["price"] = grid_trigger_f + projected = accumulated + [(grid_trigger_f, leg_contracts)] + avg = weighted_avg_entry(projected) + if avg is not None: + row["avg_entry"] = avg + row_cum = cum_contracts + leg_contracts + elif grid_trigger_f is not None: + row["price"] = grid_trigger_f + + avg_entry = row.get("avg_entry") + if avg_entry is not None and row_cum > 0: + profit_u = calc_tp_profit_usdt( + direction, float(avg_entry), user_tp, row_cum, contract_size + ) + row["take_profit_price"] = user_tp + row["profit_u"] = profit_u + row["risk_u"] = risk_u + row["rr"] = calc_money_reward_risk_ratio(profit_u, risk_u) if profit_u is not None else None + row["take_profit"] = profit_u + row["stop_loss"] = risk_u + row["stop_loss_price"] = sl + out.append(row) + return out diff --git a/strategy_trend_register.py b/lib/strategy/strategy_trend_register.py similarity index 95% rename from strategy_trend_register.py rename to lib/strategy/strategy_trend_register.py index 7d18255..ed5856c 100644 --- a/strategy_trend_register.py +++ b/lib/strategy/strategy_trend_register.py @@ -1,1914 +1,1914 @@ -"""趋势回调:路由、轮询、页面数据(四所共用,依赖各 app 模块交易所能力)。""" -from __future__ import annotations - -import inspect -import json -import os -import time -import uuid -from typing import Any, Optional - -from flask import Flask, flash, redirect, request, url_for -from jinja2 import ChoiceLoader, FileSystemLoader - -from strategy_config import resolve_trading_app_module -from strategy_db import init_strategy_tables -from strategy_trend_exchange import ( - cancel_symbol_orders, - trend_market_add, - trend_market_close, - trend_refresh_stop_only, - trend_replace_tpsl, -) -from strategy_trend_lib import ( - build_grid_prices, - build_leg_amounts_json, - calc_risk_fraction, - trend_dca_level_reached, - trend_effective_margin_capital, - validate_trend_bounds, -) -from strategy_trade_labels import ( - ENTRY_REASON_TREND_PULLBACK, - MONITOR_TYPE_TREND_PULLBACK, - TREND_HANDOFF_KEY_SIGNAL, - TREND_HANDOFF_TRADE_NOTE, -) - -MONITOR_TYPE_TREND = MONITOR_TYPE_TREND_PULLBACK - -# 趋势回调:交易所报空仓需连续 N 次轮询确认,避免 OKX 等 API 瞬时误判立即结束计划 -_TREND_FLAT_STREAK: dict[int, int] = {} -TREND_FLAT_CONFIRM_POLLS = max(1, int(os.getenv("TREND_FLAT_CONFIRM_POLLS", "5"))) -TREND_OPEN_GRACE_SEC = max(0, int(os.getenv("TREND_OPEN_GRACE_SEC", "180"))) -_TREND_LIVE_SKIP_LOG_TS = 0.0 -_TREND_POLL_STATE: dict[str, Any] = { - "updated_at": None, - "live_ok": True, - "live_reason": "", - "plans": {}, -} - - -def get_trend_poll_state() -> dict: - return dict(_TREND_POLL_STATE or {}) - - -def _log_trend_live_skip(reason: str) -> None: - global _TREND_LIVE_SKIP_LOG_TS - now = time.time() - if now - _TREND_LIVE_SKIP_LOG_TS < 60: - return - _TREND_LIVE_SKIP_LOG_TS = now - print(f"[trend_pullback] poll skipped (live not ready): {reason}", flush=True) - - -def _set_trend_poll_plan(plan_id: int, info: dict) -> None: - plans = dict(_TREND_POLL_STATE.get("plans") or {}) - plans[str(plan_id)] = info - _TREND_POLL_STATE["plans"] = plans - - -def summarize_trend_dca_probe(cfg: dict, row) -> dict: - """诊断单计划为何未补仓(供页面 / API)。""" - m = _m(cfg) - d = _row(cfg, row) - plan_id = int(d.get("id") or 0) - sym = d.get("symbol") or "" - direction = (d.get("direction") or "long").lower() - ex_sym = d.get("exchange_symbol") or m.normalize_exchange_symbol(sym) - out: dict[str, Any] = { - "plan_id": plan_id, - "symbol": sym, - "mark_price": None, - "next_trigger": None, - "trigger_reached": False, - "legs_done": int(d.get("legs_done") or 0), - "first_order_done": int(d.get("first_order_done") or 0), - "block_reason": None, - } - try: - legs_done = int(d.get("legs_done") or 0) - grid = json.loads(d.get("grid_prices_json") or "[]") - if not isinstance(grid, list): - grid = [] - leg_amounts = json.loads(d.get("leg_amounts_json") or "[]") - if not isinstance(leg_amounts, list): - leg_amounts = [] - except Exception: - grid = [] - leg_amounts = [] - legs_done = 0 - pf = _trend_poll_price(m, sym, ex_sym, direction) - out["mark_price"] = pf - ok_live, live_reason = m.ensure_exchange_live_ready() - out["live_ok"] = ok_live - if not ok_live: - out["block_reason"] = live_reason or "实盘未就绪" - if not int(d.get("first_order_done") or 0): - out["block_reason"] = out["block_reason"] or "首仓未完成" - return out - if legs_done >= len(grid) or legs_done >= len(leg_amounts): - out["block_reason"] = out["block_reason"] or "补仓档已全部完成或无 grid" - return out - try: - level = float(grid[legs_done]) - except (TypeError, ValueError, IndexError): - out["block_reason"] = out["block_reason"] or "无效补仓触发价" - return out - out["next_trigger"] = level - if pf is None: - out["block_reason"] = out["block_reason"] or "无法读取标记价" - return out - reached = trend_dca_level_reached(direction, float(pf), level) - out["trigger_reached"] = reached - if reached and not ok_live: - out["block_reason"] = live_reason or "LIVE_TRADING_ENABLED=false" - elif reached and ok_live: - pos = m.get_live_position_contracts(ex_sym, direction) - try: - local_open = float(d.get("order_amount_open") or 0) - except (TypeError, ValueError): - local_open = 0.0 - if pos is None and local_open > 0: - pos = local_open - if pos is None: - out["block_reason"] = "无法读取交易所持仓" - elif float(pos) <= 0: - out["block_reason"] = "交易所无持仓" - else: - out["block_reason"] = ( - "标记价已触达,轮询应自动下单;若仍未补请确认 PM2 进程 crypto_gate_bot " - "(非 manual-agent-gate-bot)在运行,并查看 pm2 logs crypto_gate_bot" - ) - elif not reached: - out["block_reason"] = f"标记价 {pf} 未触达下一档 {level}" - return out - - -def trend_add_zone_label(direction: str) -> str: - return "补仓下沿" if (direction or "long").strip().lower() == "short" else "补仓上沿" - - -def install_strategy_trend(app: Flask, repo_root: str, app_module: Any = None, **build_kw) -> dict: - from strategy_register import attach_strategy_templates - - attach_strategy_templates(app, repo_root) - cfg = build_trend_config(app_module, **build_kw) - app.extensions["strategy_trend_cfg"] = cfg - register_trend_routes(app, cfg) - _patch_hub_monitor_enrich(app, cfg) - roll_cfg = app.extensions.get("strategy_roll_cfg") - if isinstance(roll_cfg, dict): - from strategy_roll_ui_lib import patch_roll_hub_enrich - - patch_roll_hub_enrich(app, roll_cfg) - _patch_hub_trend_views(app) - - @app.context_processor - def _trend_ctx(): - return {"trend_add_zone_label": trend_add_zone_label} - - return cfg - - -def build_trend_config(app_module: Any = None, **kw) -> dict[str, Any]: - m = resolve_trading_app_module(app_module) - dca = max(1, int(os.getenv("TREND_PULLBACK_DCA_LEGS", kw.get("dca_legs", "5")))) - preview_ttl = max(10, int(os.getenv("TREND_PULLBACK_PREVIEW_TTL_SECONDS", "120"))) - drift = float(os.getenv("TREND_PREVIEW_MAX_BALANCE_DRIFT_PCT", "5")) - be_pct = float(os.getenv("TREND_PULLBACK_MANUAL_BREAKEVEN_OFFSET_PCT", "0.3")) - buf = float(getattr(m, "FULL_MARGIN_BUFFER_RATIO", 0.95)) - - def amount_precise(ex_sym, amt): - fn = getattr(m, "_safe_amount_to_precision", None) - if callable(fn): - return fn(ex_sym, amt) - try: - m.ensure_markets_loaded() - return float(m.exchange.amount_to_precision(ex_sym, float(amt))) - except Exception: - return None - - def send_wechat(content): - fn = getattr(m, "send_wechat_msg", None) - if callable(fn): - fn(content) - - def wechat_account_label(): - fn = getattr(m, "_wechat_account_label", None) - if callable(fn): - try: - return fn() - except Exception: - pass - return getattr(m, "EXCHANGE_DISPLAY_NAME", "") or "" - - def wechat_direction_text(direction): - fn = getattr(m, "_wechat_direction_text", None) - if callable(fn): - try: - return fn(direction) - except Exception: - pass - d = (direction or "long").strip().lower() - return "做多" if d == "long" else "做空" - - return { - "app_module": m, - "exchange_display": getattr(m, "EXCHANGE_DISPLAY_NAME", ""), - "login_required": m.login_required, - "get_db": m.get_db, - "row_to_dict": m.row_to_dict, - "dca_legs": dca, - "preview_ttl": preview_ttl, - "drift_pct": drift, - "breakeven_offset_pct": be_pct, - "margin_buffer": buf, - "amount_precise": amount_precise, - "max_active_positions": int(getattr(m, "MAX_ACTIVE_POSITIONS", 1)), - "reset_hour": int(getattr(m, "TRADING_DAY_RESET_HOUR", 8)), - "monitor_type_trend": MONITOR_TYPE_TREND, - "send_wechat": send_wechat, - "format_price": getattr(m, "format_price_for_symbol", None), - "wechat_account_label": wechat_account_label, - "wechat_direction_text": wechat_direction_text, - } - - -def _m(cfg: dict): - return cfg["app_module"] - - -def _row(cfg, row) -> dict: - return cfg["row_to_dict"](row) - - -def precheck_trend_start(cfg: dict, conn) -> tuple[bool, str]: - m = _m(cfg) - mode = getattr(m, "POSITION_SIZING_MODE", None) or "risk" - try: - from position_sizing_lib import OPEN_SOURCE_TREND, assert_open_source_allowed - - ok_src, src_msg = assert_open_source_allowed(mode, OPEN_SOURCE_TREND) - if not ok_src: - return False, src_msg - except Exception: - pass - now = m.app_now() - if not m.trading_day_reset_allows_new_open(now): - return False, f"北京时间 {cfg['reset_hour']}:00 前不允许持仓" - active = m.get_active_position_count(conn) - if active >= cfg["max_active_positions"]: - return ( - False, - f"已达最大持仓数({active}/{cfg['max_active_positions']})," - "请先结束「实盘下单」中的持仓,再启动趋势回调", - ) - trend_n = conn.execute( - "SELECT COUNT(*) FROM trend_pullback_plans WHERE status='active'" - ).fetchone()[0] - if int(trend_n or 0) > 0: - return False, "已存在运行中的趋势回调计划" - return True, "" - - -def _cleanup_stale_previews(conn) -> None: - ms = int(time.time() * 1000) - stale = conn.execute( - "SELECT id FROM trend_pullback_previews WHERE expires_at_ms < ?", (ms,) - ).fetchall() - for row in stale: - try: - conn.execute( - "UPDATE trend_pullback_preview_snapshots SET outcome='expired' " - "WHERE preview_id=? AND outcome='open'", - (row["id"],), - ) - except Exception: - pass - conn.execute("DELETE FROM trend_pullback_previews WHERE expires_at_ms < ?", (ms,)) - - -def parse_trend_plan(cfg: dict, form_dict) -> tuple[Optional[dict], Optional[str]]: - m = _m(cfg) - d = form_dict or {} - symbol = m.normalize_symbol_input(d.get("symbol")) - if not symbol: - return None, "symbol 不能为空" - direction = (d.get("direction") or "long").strip().lower() - if direction not in ("long", "short"): - return None, "方向错误" - try: - stop_loss = float(d.get("sl")) - add_upper = float(d.get("add_upper")) - take_profit = float(d.get("take_profit")) - risk_percent = float(d.get("risk_percent") or "5") - except Exception: - return None, "价格或风险比例格式错误" - try: - lev_raw = m.parse_positive_float(d.get("leverage")) - leverage = int(lev_raw) if lev_raw is not None else m.infer_leverage(symbol) - except Exception: - return None, "杠杆格式错误" - if leverage <= 0 or risk_percent <= 0: - return None, "杠杆与风险比例必须大于0" - bound_err = validate_trend_bounds(direction, stop_loss, add_upper) - if bound_err: - return None, bound_err - snap = m.get_available_trading_usdt() - if snap is None or snap <= 0: - return None, "无法读取合约账户 USDT 可用余额,请检查 API 与账户类型" - live_price = m.get_price(symbol) - if live_price is None: - return None, "获取实时价格失败" - exchange_symbol = m.normalize_exchange_symbol(symbol) - rf = calc_risk_fraction(direction, add_upper, stop_loss) - if rf is None or rf <= 0: - return None, "止损与补仓区间边界组合无法计算风险比例" - risk_budget = float(snap) * (risk_percent / 100.0) - notional = risk_budget / rf - margin_plan = notional / float(leverage) - margin_plan = min(margin_plan, float(snap) * cfg["margin_buffer"]) - if margin_plan <= 0: - return None, "计划保证金过小" - try: - target_amt, _ = m.prepare_order_amount(exchange_symbol, margin_plan, leverage, live_price) - except Exception as e: - return None, str(e) - ap = cfg["amount_precise"] - first_amt = ap(exchange_symbol, float(target_amt) * 0.5) - if first_amt is None or first_amt <= 0: - return None, "首仓张数过小(低于交易所最小张数),请提高风险比例或杠杆" - remainder_total = ap(exchange_symbol, max(0.0, float(target_amt) - float(first_amt))) - if remainder_total is None: - remainder_total = 0.0 - m.ensure_markets_loaded() - market = m.exchange.market(exchange_symbol) - min_amt = float((market.get("limits", {}).get("amount", {}) or {}).get("min") or 0) - n_legs, leg_json, per_ref = build_leg_amounts_json( - exchange_symbol, remainder_total, cfg["dca_legs"], ap, min_amt - ) - if n_legs <= 0: - return None, "剩余计划张数不足以拆出补仓档,请提高风险比例或放宽止损与补仓区间间距" - grid = build_grid_prices(direction, stop_loss, add_upper, n_legs) - if len(grid) != n_legs: - return None, "补仓网格生成失败" - opened_at = m.app_now_str() - try: - leg_list = json.loads(leg_json) - except Exception: - leg_list = [] - contract_size = float(market.get("contractSize") or 1) - return { - "symbol": symbol, - "exchange_symbol": exchange_symbol, - "direction": direction, - "leverage": leverage, - "stop_loss": stop_loss, - "add_upper": add_upper, - "take_profit": take_profit, - "risk_percent": risk_percent, - "snapshot_available_usdt": float(snap), - "snapshot_at": opened_at, - "live_price_ref": float(live_price), - "plan_margin_capital": float(margin_plan), - "target_order_amount": float(target_amt), - "first_order_amount": float(first_amt), - "remainder_total": float(remainder_total), - "dca_legs": int(n_legs), - "per_leg_amount": float(per_ref), - "grid_prices_json": json.dumps(grid), - "leg_amounts_json": leg_json, - "grid": grid, - "leg_amounts": leg_list, - "contract_size": contract_size, - }, None - - -def _insert_preview_snapshot(conn, preview_id: str, created: str, exp_ms: int, pl: dict) -> None: - conn.execute( - """INSERT INTO trend_pullback_preview_snapshots ( - preview_id,symbol,exchange_symbol,direction,leverage,stop_loss,add_upper,take_profit,risk_percent, - snapshot_available_usdt,snapshot_at,live_price_ref,plan_margin_capital,target_order_amount,first_order_amount,remainder_total, - dca_legs,per_leg_amount,grid_prices_json,leg_amounts_json,expires_at_ms,preview_created_at - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - preview_id, - pl["symbol"], - pl["exchange_symbol"], - pl["direction"], - pl["leverage"], - pl["stop_loss"], - pl["add_upper"], - pl["take_profit"], - pl["risk_percent"], - pl["snapshot_available_usdt"], - pl["snapshot_at"], - pl["live_price_ref"], - pl["plan_margin_capital"], - pl["target_order_amount"], - pl["first_order_amount"], - pl["remainder_total"], - pl["dca_legs"], - pl["per_leg_amount"], - pl["grid_prices_json"], - pl["leg_amounts_json"], - exp_ms, - created, - ), - ) - - -def _format_trend_price(cfg: dict, symbol: str, value) -> str: - if value in (None, ""): - return "—" - m = _m(cfg) - sym = symbol or "" - norm = getattr(m, "normalize_exchange_symbol", None) - if callable(norm): - try: - sym = norm(sym) or sym - except Exception: - pass - try: - m.ensure_markets_loaded() - return str(m.exchange.price_to_precision(sym, float(value))) - except Exception: - fn = getattr(m, "format_price_for_symbol", None) - if callable(fn): - return fn(symbol, value) - return str(value) - - -def _trend_add_leg_fields(cfg: dict, d: dict) -> dict: - """解析已补仓次数与已触达网格价(供策略页与中控 monitor 共用)。""" - import json - - out = dict(d) - try: - legs_done = int(out.get("legs_done") or 0) - except (TypeError, ValueError): - legs_done = 0 - try: - dca_legs = int(out.get("dca_legs") or 0) - except (TypeError, ValueError): - dca_legs = 0 - try: - grid = json.loads(out.get("grid_prices_json") or "[]") - if not isinstance(grid, list): - grid = [] - except Exception: - grid = [] - add_prices: list[float] = [] - try: - from strategy_trend_lib import trend_leg_display_price - - for i in range(1, legs_done + 1): - px = trend_leg_display_price(out, i) - if px is not None: - add_prices.append(float(px)) - except Exception: - pass - if not add_prices: - for x in grid[:legs_done]: - try: - add_prices.append(float(x)) - except (TypeError, ValueError): - pass - sym = out.get("exchange_symbol") or out.get("symbol") or "" - out["add_count"] = legs_done - out["add_count_total"] = dca_legs - out["add_prices"] = add_prices - out["add_prices_display"] = [_format_trend_price(cfg, sym, p) for p in add_prices] - for field in ("stop_loss", "take_profit", "add_upper", "avg_entry_price"): - if out.get(field) not in (None, ""): - out[f"{field}_display"] = _format_trend_price(cfg, sym, out.get(field)) - return out - - -def enrich_trend_plan_for_hub(cfg: dict, raw: dict) -> dict: - """中控 /api/hub/monitor:与策略页运行中计划卡片同字段(浮盈亏、标记价、盈亏比等)。""" - d = enrich_trend_plan(cfg, dict(raw or {})) - d["monitor_source"] = "趋势回调计划" - m = _m(cfg) - try: - snap = float(d.get("snapshot_available_usdt") or 0) - margin = float(d.get("plan_margin_capital") or 0) - if snap > 0 and margin > 0: - d["position_ratio_pct"] = round(margin / snap * 100.0, 2) - except (TypeError, ValueError): - pass - return d - - -def _patch_hub_trend_views(app: Flask) -> None: - """将趋势回调路由注册进 HUB_CTX.views,供中控 /api/hub/trend/* 调用。""" - ctx = dict(app.config.get("HUB_CTX") or {}) - views = dict(ctx.get("views") or {}) - for name in ( - "preview_trend_pullback", - "execute_trend_pullback", - "stop_trend_pullback", - "trend_pullback_breakeven", - ): - vf = app.view_functions.get(name) - if vf is not None: - views[name] = vf - ctx["views"] = views - app.config["HUB_CTX"] = ctx - - -def patch_trend_hub_enrich(app: Flask, cfg: dict) -> None: - """hub_bridge install 之后调用:四所 /api/hub/monitor 趋势字段与策略页一致。""" - _patch_hub_monitor_enrich(app, cfg) - - -def _patch_hub_monitor_enrich(app: Flask, cfg: dict) -> None: - ctx = dict(app.config.get("HUB_CTX") or {}) - prev = ctx.get("enrich_monitor") - - def enrich_monitor(keys=None, orders=None, trends=None, rolls=None): - payload: dict[str, Any] = {} - if callable(prev): - try: - prev_out = prev(keys=keys, orders=orders, trends=trends, rolls=rolls) - if isinstance(prev_out, dict): - payload.update(prev_out) - except Exception: - pass - if trends: - payload["trends"] = [ - enrich_trend_plan_for_hub(cfg, t) for t in trends if isinstance(t, dict) - ] - return payload - - ctx["enrich_monitor"] = enrich_monitor - app.config["HUB_CTX"] = ctx - - -def enrich_trend_plan(cfg: dict, row) -> dict: - m = _m(cfg) - d = _row(cfg, row) - try: - d["breakeven_applied"] = int(d.get("breakeven_applied") or 0) != 0 - except Exception: - d["breakeven_applied"] = False - ex_sym = d.get("exchange_symbol") or m.normalize_exchange_symbol(d.get("symbol") or "") - direction = (d.get("direction") or "long").lower() - metrics_fn = getattr(m, "get_live_position_exchange_metrics", None) - met = None - if callable(metrics_fn): - try: - lev = int(d.get("leverage") or 0) or None - except (TypeError, ValueError): - lev = None - try: - met = metrics_fn(ex_sym, direction, order_leverage=lev) - except TypeError: - met = metrics_fn(ex_sym, direction) - if met and met.get("entry_price") is not None: - try: - live_entry = float(met["entry_price"]) - if live_entry > 0: - d["avg_entry_price"] = live_entry - except (TypeError, ValueError): - pass - if met and met.get("unrealized_pnl") is not None: - d["floating_pnl"] = float(met["unrealized_pnl"]) - elif ( - met - and met.get("mark_price") is not None - and d.get("avg_entry_price") is not None - ): - try: - from hub_position_metrics import estimate_linear_swap_upnl_usdt - - entry = float(d["avg_entry_price"]) - mark = float(met["mark_price"]) - qty = None - cs = 1.0 - get_qty = getattr(m, "get_live_position_contracts", None) - get_cs = getattr(m, "get_contract_size", None) - if callable(get_qty): - qty = get_qty(ex_sym, direction) - if callable(get_cs): - cs = float(get_cs(ex_sym)) - upnl = estimate_linear_swap_upnl_usdt( - direction, entry, mark, qty, cs - ) - d["floating_pnl"] = float(upnl) if upnl is not None else None - except (TypeError, ValueError): - d["floating_pnl"] = None - else: - d["floating_pnl"] = None - if met and met.get("mark_price") is not None: - d["floating_mark"] = float(met["mark_price"]) - else: - d["floating_mark"] = None - else: - d["floating_pnl"] = d["floating_mark"] = None - get_cs = getattr(m, "get_contract_size", None) - if callable(get_cs): - try: - d["contract_size"] = float(get_cs(ex_sym)) - except (TypeError, ValueError): - pass - d = _trend_add_leg_fields(cfg, d) - from strategy_snapshot_lib import attach_trend_dca_levels - from strategy_trend_lib import calc_trend_plan_money_metrics - - d = attach_trend_dca_levels(d) - money = calc_trend_plan_money_metrics(d) - if money.get("money_rr") is not None: - d["money_rr"] = money["money_rr"] - d["planned_rr"] = money["money_rr"] - if money.get("risk_amount_u") is not None: - d["risk_amount_u"] = money["risk_amount_u"] - try: - d["breakeven_default_offset_pct"] = float(cfg.get("breakeven_offset_pct", 0.3)) - except (TypeError, ValueError): - d["breakeven_default_offset_pct"] = 0.3 - return d - - -def _weighted_avg(old_avg, old_amt, fill_px, add_amt): - try: - oa, aa = float(old_amt), float(add_amt) - if oa <= 0: - return float(fill_px) - return (float(old_avg) * oa + float(fill_px) * aa) / (oa + aa) - except Exception: - return float(fill_px or 0) - - -def _plan_stop_status(result_label: str) -> str: - if result_label == "止盈": - return "stopped_tp" - if result_label == "止损": - return "stopped_sl" - return "stopped_manual" - - -def _call_insert_trade_record(m, plan_id: int, kwargs: dict) -> None: - """按各所 insert_trade_record 签名过滤参数,避免未知字段导致记账失败。""" - fn = getattr(m, "insert_trade_record", None) - if not callable(fn): - raise RuntimeError("app_module 缺少 insert_trade_record") - allowed = set(inspect.signature(fn).parameters.keys()) - call = {k: v for k, v in kwargs.items() if k in allowed} - if "trend_plan_id" in allowed: - call["trend_plan_id"] = int(plan_id) - fn(**call) - - -def _best_trend_close_snapshot(conn, plan_id: int) -> dict | None: - from strategy_snapshot_lib import ( - FINAL_TREND_CLOSE_LABELS, - STRATEGY_TREND, - _final_trend_close_rank, - ) - - rows = conn.execute( - f"""SELECT * FROM strategy_trade_snapshots - WHERE strategy_type=? AND source_id=? - AND result_label IN ({",".join("?" * len(FINAL_TREND_CLOSE_LABELS))})""", - (STRATEGY_TREND, int(plan_id), *FINAL_TREND_CLOSE_LABELS), - ).fetchall() - if not rows: - return None - parsed = [_row_dict(row) for row in rows] - return max( - parsed, - key=lambda d: ( - _final_trend_close_rank(str(d.get("result_label") or "")), - int(d.get("id") or 0), - ), - ) - - -def _ensure_trend_plan_trade_record( - cfg: dict, conn, plan_id: int, *, prefer_label: str = "手动平仓" -) -> bool: - """计划已结束但 trade_records 缺失时,从策略快照补录一条。""" - if _trend_plan_trade_exists(conn, plan_id): - return True - m = _m(cfg) - plan = conn.execute( - "SELECT * FROM trend_pullback_plans WHERE id=?", (int(plan_id),) - ).fetchone() - if not plan: - return False - plan_d = _row_dict(plan) - snap = _best_trend_close_snapshot(conn, plan_id) - if not snap: - return False - try: - payload = json.loads(snap.get("snapshot_json") or "{}") - except Exception: - payload = {} - sym = snap.get("symbol") or plan_d.get("symbol") or payload.get("symbol") - direction = snap.get("direction") or plan_d.get("direction") or "long" - result = (prefer_label or "").strip() or (snap.get("result_label") or "").strip() or "手动平仓" - opened_at = snap.get("opened_at") or plan_d.get("opened_at") - closed_at = snap.get("closed_at") - pnl_amount = snap.get("pnl_amount") - if pnl_amount is None: - pnl_amount = payload.get("pnl_amount") - avg_e = float(payload.get("avg_entry_price") or plan_d.get("avg_entry_price") or 0) - margin_cap = trend_effective_margin_capital(plan_d) - lev = int(plan_d.get("leverage") or 1) - hold_seconds = m.calc_hold_seconds( - opened_at or "", - m.parse_dt_for_trading_day(closed_at) or m.app_now(), - ) - res = m.normalize_result_with_pnl(result, float(pnl_amount or 0)) - risk_amt = m.calc_risk_amount_from_plan( - direction, - float(plan_d.get("add_upper") or 0), - float(plan_d.get("stop_loss") or 0), - float(plan_d.get("plan_margin_capital") or 0), - lev, - ) - planned_rr = m.calc_rr_ratio( - direction, - avg_e, - float(plan_d.get("stop_loss") or 0), - float(plan_d.get("take_profit") or 0), - ) - session_date = plan_d.get("session_date") or m.get_trading_day() - _bump_session_capital_no_commit(m, conn, session_date, float(pnl_amount or 0)) - _call_insert_trade_record( - m, - plan_id, - dict( - conn=conn, - symbol=sym, - monitor_type=MONITOR_TYPE_TREND, - direction=direction, - trigger_price=avg_e, - stop_loss=float(plan_d.get("stop_loss") or 0), - initial_stop_loss=float(plan_d.get("initial_stop_loss") or plan_d.get("stop_loss") or 0), - take_profit=float(plan_d.get("take_profit") or 0), - margin_capital=margin_cap, - leverage=lev, - pnl_amount=pnl_amount, - hold_seconds=hold_seconds, - trade_style="trend_pullback", - risk_amount=risk_amt, - planned_rr=planned_rr, - actual_rr=m.calc_actual_rr(pnl_amount, risk_amt), - result=res, - opened_at=opened_at, - closed_at=closed_at, - entry_reason=ENTRY_REASON_TREND_PULLBACK, - ), - ) - conn.commit() - return True - - -def sync_trend_plans_after_external_close( - cfg: dict, conn, symbol: str, direction: str -) -> dict[str, Any]: - """中控/外部全平后:结束仍 active 的同币种同向趋势计划(避免监控再记一条止损)。""" - m = _m(cfg) - sym = m.normalize_symbol_input(symbol) if hasattr(m, "normalize_symbol_input") else (symbol or "").strip() - if not sym: - return {"ok": False, "msg": "symbol 无效", "finalized": 0} - direction = (direction or "long").strip().lower() - rows = conn.execute( - "SELECT * FROM trend_pullback_plans WHERE status='active' AND symbol=? AND direction=?", - (sym, direction), - ).fetchall() - finalized = 0 - for row in rows: - px = m.get_price(row["symbol"]) - exit_p = float(px) if px is not None else 0.0 - before = _trend_plan_trade_exists(conn, int(row["id"])) - _finalize_plan(cfg, conn, row, "手动平仓", exit_p) - if not before: - finalized += 1 - return {"ok": True, "finalized": finalized, "symbol": sym, "direction": direction} - - -def _trend_plan_trade_exists(conn, plan_id: int) -> bool: - try: - return conn.execute( - "SELECT id FROM trade_records WHERE trend_plan_id=? LIMIT 1", - (int(plan_id),), - ).fetchone() is not None - except Exception: - return False - - -def _bump_session_capital_no_commit( - m, conn, session_date: str, pnl_amount: float -) -> float | None: - """更新当日资金,不单独 commit(与 _finalize_plan 同一事务)。""" - try: - row = conn.execute( - "SELECT current_capital FROM trading_sessions WHERE session_date = ?", - (session_date,), - ).fetchone() - if not row: - start_cap = float(getattr(m, "DAILY_START_CAPITAL", 0) or 0) - if start_cap <= 0: - ensure = getattr(m, "ensure_session", None) - if callable(ensure): - ensured = ensure(conn, session_date) - row = ensured - else: - return None - else: - conn.execute( - "INSERT OR IGNORE INTO trading_sessions " - "(session_date, start_capital, current_capital) VALUES (?,?,?)", - (session_date, start_cap, start_cap), - ) - row = conn.execute( - "SELECT current_capital FROM trading_sessions WHERE session_date = ?", - (session_date,), - ).fetchone() - if not row: - return None - new_capital = float(row["current_capital"]) + float(pnl_amount) - conn.execute( - "UPDATE trading_sessions SET current_capital = ?, updated_at = CURRENT_TIMESTAMP " - "WHERE session_date = ?", - (round(new_capital, 4), session_date), - ) - return round(new_capital, 4) - except Exception: - return None - - -def _apply_trend_user_risk_close(cfg: dict, conn, *, trade_record_id=None, closed_at_ms=None) -> None: - m = _m(cfg) - fn = getattr(m, "hub_user_initiated_close", None) - from account_risk_lib import CLOSE_SOURCE_USER_TREND_STOP - - if callable(fn): - fn( - conn, - source=CLOSE_SOURCE_USER_TREND_STOP, - count=1, - trade_record_id=trade_record_id, - closed_at_ms=closed_at_ms, - ) - return - from account_risk_lib import on_user_initiated_close - - on_user_initiated_close( - conn, - source=CLOSE_SOURCE_USER_TREND_STOP, - trade_record_id=trade_record_id, - closed_at_ms=closed_at_ms, - trading_day=m.get_trading_day(), - now=m.app_now(), - count=1, - ) - - -def _finalize_plan(cfg: dict, conn, row, result_label: str, exit_price: float, *, user_initiated_risk: bool = False) -> None: - m = _m(cfg) - plan_id = int(row["id"]) - active = conn.execute( - "SELECT * FROM trend_pullback_plans WHERE id=? AND status='active'", - (plan_id,), - ).fetchone() - if not active: - return - row = active - sym = row["symbol"] - direction = row["direction"] or "long" - ex_sym = row["exchange_symbol"] or m.normalize_exchange_symbol(sym) - closed_at = m.app_now_str() - opened_at = row["opened_at"] or closed_at - hold_seconds = m.calc_hold_seconds(opened_at, m.parse_dt_for_trading_day(closed_at) or m.app_now()) - plan_margin = float(row["plan_margin_capital"] or 0) - margin_cap = trend_effective_margin_capital(_row(cfg, row)) - lev = int(row["leverage"] or 1) - avg_e = float(row["avg_entry_price"] or 0) - pnl_amount = m.calc_pnl(direction, avg_e, float(exit_price), margin_cap, lev) - res = m.normalize_result_with_pnl(result_label, pnl_amount) - risk_amt = m.calc_risk_amount_from_plan( - direction, float(row["add_upper"]), float(row["stop_loss"]), plan_margin, lev - ) - try: - target = float(row["target_order_amount"] or 0) - open_amt = float(row["order_amount_open"] or 0) - if risk_amt is not None and target > 0 and open_amt > 0: - risk_amt = round(float(risk_amt) * min(1.0, open_amt / target), 6) - except (TypeError, ValueError): - pass - planned_rr = m.calc_rr_ratio(direction, avg_e, float(row["stop_loss"]), float(row["take_profit"])) - st = _plan_stop_status(result_label) - cur = conn.execute( - "UPDATE trend_pullback_plans SET status=?, message=? WHERE id=? AND status='active'", - (st, res, plan_id), - ) - if not getattr(cur, "rowcount", 0): - return - try: - from strategy_snapshot_lib import save_trend_plan_snapshot - - save_trend_plan_snapshot( - cfg, - conn, - row, - result_label=result_label, - exit_price=float(exit_price) if exit_price is not None else None, - pnl_amount=float(pnl_amount) if pnl_amount is not None else None, - closed_at=closed_at, - ) - except Exception: - pass - try: - cancel_symbol_orders(cfg, ex_sym) - except Exception: - pass - session_capital = None - trade_record_id = None - if not _trend_plan_trade_exists(conn, plan_id): - session_date = row["session_date"] or m.get_trading_day() - session_capital = _bump_session_capital_no_commit( - m, conn, session_date, pnl_amount - ) - _call_insert_trade_record( - m, - plan_id, - dict( - conn=conn, - symbol=sym, - monitor_type=MONITOR_TYPE_TREND, - direction=direction, - trigger_price=avg_e, - stop_loss=float(row["stop_loss"]), - initial_stop_loss=float(row.get("initial_stop_loss") or row["stop_loss"]), - take_profit=float(row["take_profit"]), - margin_capital=margin_cap, - leverage=lev, - pnl_amount=pnl_amount, - hold_seconds=hold_seconds, - trade_style="trend_pullback", - risk_amount=risk_amt, - planned_rr=planned_rr, - actual_rr=m.calc_actual_rr(pnl_amount, risk_amt), - result=res, - opened_at=opened_at, - closed_at=closed_at, - entry_reason=ENTRY_REASON_TREND_PULLBACK, - ), - ) - try: - from account_risk_lib import insert_trade_record_id - - trade_record_id = insert_trade_record_id(conn) - except Exception: - trade_record_id = None - if user_initiated_risk: - closed_ms = None - to_ms = getattr(m, "_to_ms_with_fallback", None) - if callable(to_ms): - try: - closed_ms = to_ms(None, closed_at) - except Exception: - closed_ms = None - _apply_trend_user_risk_close( - cfg, - conn, - trade_record_id=trade_record_id, - closed_at_ms=closed_ms, - ) - conn.commit() - try: - from strategy_wechat_notify import notify_trend_plan_ended - - notify_trend_plan_ended( - cfg, - plan_id=plan_id, - symbol=sym, - direction=direction, - end_type=result_label, - result_label=res, - exit_price=float(exit_price) if exit_price is not None else None, - pnl_amount=float(pnl_amount) if pnl_amount is not None else None, - ) - except Exception: - pass - extra = getattr(m, "build_wechat_close_message", None) - send = getattr(m, "send_wechat_msg", None) - if callable(extra) and callable(send): - send( - extra( - symbol=sym, - direction=direction, - result=f"{res}({MONITOR_TYPE_TREND})", - pnl_amount=pnl_amount, - hold_seconds=hold_seconds, - trigger_price=avg_e, - current_price=float(exit_price), - stop_loss=float(row["stop_loss"]), - take_profit=float(row["take_profit"]), - close_order_id="-", - extra_note="计划本金口径:启动时合约可用余额快照;止盈由程序监控", - session_capital_fallback=session_capital, - ) - ) - - -def _trend_plan_open_age_sec(row, m) -> float: - opened_ms = None - try: - if "opened_at_ms" in row.keys() and row["opened_at_ms"]: - opened_ms = int(row["opened_at_ms"]) - except Exception: - opened_ms = None - to_ms = getattr(m, "_to_ms_with_fallback", None) - if callable(to_ms): - opened_ms = to_ms(opened_ms, row["opened_at"] if "opened_at" in row.keys() else None) - if opened_ms is None and "opened_at" in row.keys(): - opened_ms = to_ms(None, row["opened_at"]) - if not opened_ms: - return 0.0 - return max(0.0, (time.time() * 1000 - opened_ms) / 1000.0) - - -def _trend_hit_take_profit(direction: str, mark_price: float, take_profit: float, avg_entry: float) -> bool: - try: - pf = float(mark_price) - tp = float(take_profit) - entry = float(avg_entry) - except (TypeError, ValueError): - return False - if entry <= 0 or tp <= 0: - return False - direction = (direction or "long").lower() - if direction == "long": - return tp > entry and pf >= tp - return tp < entry and pf <= tp - - -def _trend_poll_price(m, sym: str, ex_sym: str, direction: str) -> Optional[float]: - """补仓/止盈判定用标记价(与页面「标记价」一致),无标记价时回退 last。""" - fn = getattr(m, "get_symbol_mark_price", None) - if callable(fn): - try: - px = fn(sym) - if px is not None and float(px) > 0: - return float(px) - except Exception: - pass - metrics_fn = getattr(m, "get_live_position_exchange_metrics", None) - if callable(metrics_fn): - try: - met = metrics_fn(ex_sym, direction) - if met and met.get("mark_price") is not None: - px = float(met["mark_price"]) - if px > 0: - return px - except Exception: - pass - px = m.get_price(sym) - try: - return float(px) if px is not None else None - except (TypeError, ValueError): - return None - - -def _should_finalize_trend_flat(row, pos, plan_id: int, m) -> bool: - """首仓后交易所报无仓:需过开仓宽限期 + 连续空仓轮询,避免误判止损。""" - if pos is None: - return False - if float(pos) > 0: - _TREND_FLAT_STREAK.pop(plan_id, None) - return False - if not int(row["first_order_done"] or 0): - return False - age = _trend_plan_open_age_sec(row, m) - if age < TREND_OPEN_GRACE_SEC: - _TREND_FLAT_STREAK.pop(plan_id, None) - return False - try: - local_open = float(row["order_amount_open"] or 0) - except (TypeError, ValueError): - local_open = 0.0 - required = TREND_FLAT_CONFIRM_POLLS - if local_open > 0 and age < TREND_OPEN_GRACE_SEC * 2: - required = max(required, TREND_FLAT_CONFIRM_POLLS * 2) - streak = int(_TREND_FLAT_STREAK.get(plan_id, 0)) + 1 - _TREND_FLAT_STREAK[plan_id] = streak - if streak >= required: - print( - f"[trend_pullback] flat finalize plan={plan_id} sym={row['symbol']} " - f"age={age:.0f}s streak={streak} local_open={local_open}", - flush=True, - ) - return True - return False - - -def check_trend_pullback_plans(cfg: dict) -> None: - m = _m(cfg) - ok_live, live_reason = m.ensure_exchange_live_ready() - _TREND_POLL_STATE["updated_at"] = time.time() - _TREND_POLL_STATE["live_ok"] = ok_live - _TREND_POLL_STATE["live_reason"] = live_reason or "" - if not ok_live: - _log_trend_live_skip(live_reason or "unknown") - conn = cfg["get_db"]() - try: - for row in conn.execute( - "SELECT * FROM trend_pullback_plans WHERE status='active'" - ).fetchall(): - probe = summarize_trend_dca_probe(cfg, row) - if probe.get("trigger_reached"): - _set_trend_poll_plan(int(row["id"]), probe) - except Exception as e: - print(f"[trend_pullback] live-skip probe error: {e}", flush=True) - finally: - conn.close() - return - conn = cfg["get_db"]() - rows = conn.execute( - "SELECT * FROM trend_pullback_plans WHERE status='active'" - ).fetchall() - for row in rows: - try: - plan_id = int(row["id"]) - sym = row["symbol"] - direction = (row["direction"] or "long").lower() - ex_sym = row["exchange_symbol"] or m.normalize_exchange_symbol(sym) - sl = float(row["stop_loss"]) - tp = float(row["take_profit"]) - lev = int(row["leverage"] or 1) - try: - local_open = float(row["order_amount_open"] or 0) - except (TypeError, ValueError): - local_open = 0.0 - pf = _trend_poll_price(m, sym, ex_sym, direction) - if pf is None: - continue - last_p = row["last_mark_price"] - last_pf = float(last_p) if last_p is not None else pf - pos = m.get_live_position_contracts(ex_sym, direction) - if pos is None: - if local_open > 0 and int(row["first_order_done"] or 0): - pos = local_open - else: - continue - elif float(pos) <= 0 and local_open > 0: - age = _trend_plan_open_age_sec(row, m) - if age < TREND_OPEN_GRACE_SEC * 2: - print( - f"[trend_pullback] pos fallback plan={plan_id} sym={sym} " - f"ex_pos=0 local_open={local_open} age={age:.0f}s", - flush=True, - ) - pos = local_open - legs_done = int(row["legs_done"] or 0) - try: - leg_amounts = [float(x) for x in json.loads(row["leg_amounts_json"] or "[]")] - except Exception: - leg_amounts = [] - try: - grid = json.loads(row["grid_prices_json"] or "[]") - except Exception: - grid = [] - avg_e = float(row["avg_entry_price"] or pf or 0) - hit_tp = _trend_hit_take_profit(direction, pf, tp, avg_e) - if hit_tp and pos > 0: - try: - close_resp = trend_market_close(cfg, ex_sym, direction, float(pos), lev) - exit_p = m.extract_trade_price_from_order(close_resp) or pf - except Exception as e: - if not m.is_no_position_error(str(e)): - continue - exit_p = pf - _finalize_plan(cfg, conn, row, "止盈", exit_p) - _TREND_FLAT_STREAK.pop(plan_id, None) - continue - if _should_finalize_trend_flat(row, pos, plan_id, m): - _finalize_plan(cfg, conn, row, "止损", pf) - _TREND_FLAT_STREAK.pop(plan_id, None) - continue - if int(row["first_order_done"] or 0) and legs_done < len(grid) and legs_done < len(leg_amounts): - while legs_done < len(grid) and legs_done < len(leg_amounts): - level = float(grid[legs_done]) - if not trend_dca_level_reached(direction, pf, level): - break - amt = float(m.exchange.amount_to_precision(ex_sym, leg_amounts[legs_done])) - if amt <= 0: - print( - f"[trend_pullback] dca skip plan={plan_id} leg={legs_done + 1} " - f"amt_precision=0 raw={leg_amounts[legs_done]}", - flush=True, - ) - break - try: - add_resp = trend_market_add(cfg, ex_sym, direction, amt, lev) - except Exception as e: - print( - f"[trend_pullback] dca order failed plan={plan_id} sym={sym} " - f"leg={legs_done + 1} level={level} mark={pf} err={e}", - flush=True, - ) - break - fill_px = m.extract_trade_price_from_order(add_resp) or pf - old_avg = float(row["avg_entry_price"] or fill_px) - old_open = float(row["order_amount_open"] or 0) - new_avg = _weighted_avg(old_avg, old_open, fill_px, amt) - legs_done += 1 - from strategy_trend_lib import append_leg_fill_price_json - - fills_json = append_leg_fill_price_json( - row["leg_fill_prices_json"] if "leg_fill_prices_json" in row.keys() else None, - fill_px, - ) - conn.execute( - "UPDATE trend_pullback_plans SET legs_done=?, avg_entry_price=?, " - "order_amount_open=?, last_mark_price=?, leg_fill_prices_json=? WHERE id=?", - (legs_done, new_avg, old_open + amt, pf, fills_json, row["id"]), - ) - row = conn.execute( - "SELECT * FROM trend_pullback_plans WHERE id=?", (row["id"],) - ).fetchone() - print( - f"[trend_pullback] dca filled plan={plan_id} leg={legs_done} " - f"fill={fill_px} avg={new_avg} open={old_open + amt}", - flush=True, - ) - try: - trend_refresh_stop_only(cfg, ex_sym, direction, sl) - except Exception: - pass - conn.execute( - "UPDATE trend_pullback_plans SET last_mark_price=? WHERE id=?", - (pf, row["id"]), - ) - probe = summarize_trend_dca_probe(cfg, row) - probe["last_poll_mark"] = pf - _set_trend_poll_plan(plan_id, probe) - if probe.get("trigger_reached") and probe.get("block_reason"): - print( - f"[trend_pullback] dca blocked plan={plan_id} sym={sym} " - f"mark={pf} next={probe.get('next_trigger')} reason={probe.get('block_reason')}", - flush=True, - ) - except Exception as e: - print( - f"[trend_pullback] poll error plan={row['id'] if row else '?'}: {e}", - flush=True, - ) - continue - conn.commit() - conn.close() - - -TREND_PLAN_STATUS_HANDOFF = "stopped_handoff" - - -def _order_monitor_manual_type(m) -> str: - return getattr(m, "ORDER_MONITOR_TYPE_MANUAL", None) or "下单监控" - - -def _insert_trend_handoff_order_monitor( - cfg: dict, - conn, - plan_row, - *, - new_sl: float, - pos_amt: float, -) -> int: - m = _m(cfg) - sym = plan_row["symbol"] - direction = (plan_row["direction"] or "long").lower() - ex_sym = plan_row["exchange_symbol"] or m.normalize_exchange_symbol(sym) - plan_id = int(plan_row["id"]) - avg_e = float(plan_row["avg_entry_price"] or 0) - tp = float(plan_row["take_profit"] or 0) - lev = int(plan_row["leverage"] or 1) - margin_cap = float(plan_row["plan_margin_capital"] or 0) - init_sl = float( - plan_row["initial_stop_loss"] - if plan_row["initial_stop_loss"] not in (None, "") - else plan_row["stop_loss"] - or 0 - ) - risk_pct = float(plan_row["risk_percent"] or 5) - risk_amt = None - calc_risk = getattr(m, "calc_risk_amount_from_plan", None) - if callable(calc_risk): - try: - risk_amt = calc_risk(direction, avg_e, init_sl, margin_cap, lev) - except Exception: - risk_amt = None - be_rr = float(getattr(m, "BREAKEVEN_RR_TRIGGER", 1) or 1) - be_off = float(getattr(m, "BREAKEVEN_OFFSET_PCT", 0.3) or 0.3) - be_step = float(getattr(m, "BREAKEVEN_STEP_R", 1) or 1) - if direction == "short": - be_price = round(avg_e * (1 - be_off / 100.0), 8) - else: - be_price = round(avg_e * (1 + be_off / 100.0), 8) - rp = getattr(m, "round_price_to_exchange", None) - if callable(rp): - try: - be_price = float(rp(ex_sym, be_price) or be_price) - except Exception: - pass - opened_at = plan_row["opened_at"] or m.app_now_str() - to_ms = getattr(m, "_to_ms_with_fallback", None) - opened_ms = to_ms(plan_row["opened_at_ms"] if "opened_at_ms" in plan_row.keys() else None, opened_at) if callable(to_ms) else None - trading_day = plan_row["session_date"] or getattr(m, "get_trading_day", lambda: None)() - if not trading_day and callable(getattr(m, "get_trading_day", None)): - trading_day = m.get_trading_day() - notional = margin_cap * lev if margin_cap and lev else None - monitor_type = MONITOR_TYPE_TREND_PULLBACK - conn.execute( - "INSERT INTO order_monitors " - "(symbol, exchange_symbol, direction, trigger_price, stop_loss, initial_stop_loss, take_profit, " - "margin_capital, leverage, trade_style, risk_percent, risk_amount, " - "breakeven_rr_trigger, breakeven_offset_pct, breakeven_step_r, breakeven_armed, breakeven_price, " - "breakeven_enabled, notional_value, position_ratio, base_amount, order_amount, exchange_order_id, " - "opened_at, opened_at_ms, session_date, monitor_type, key_signal_type, trend_plan_id) " - "VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", - ( - sym, - ex_sym, - direction, - avg_e, - new_sl, - init_sl, - tp, - margin_cap, - lev, - "trend_pullback_handoff", - risk_pct, - risk_amt, - be_rr, - be_off, - be_step, - 0, - be_price, - 0, - notional, - None, - None, - float(pos_amt), - "", - opened_at, - opened_ms, - trading_day, - monitor_type, - TREND_HANDOFF_KEY_SIGNAL, - plan_id, - ), - ) - new_id = int(conn.execute("SELECT last_insert_rowid()").fetchone()[0]) - persist = getattr(m, "try_persist_exchange_margin_for_order", None) - if callable(persist): - try: - persist(conn, new_id, ex_sym, direction, order_leverage=lev) - except Exception: - pass - return new_id - - -def apply_manual_breakeven(cfg: dict, conn, row, offset_pct=None) -> tuple[bool, Optional[str]]: - """保本:结束趋势计划,持仓移交下单监控(备注趋势回调),交易所同时挂保本止损与止盈。""" - m = _m(cfg) - if (row["status"] or "").strip() != "active": - return False, "计划已结束" - if not int(row["first_order_done"] or 0): - return False, "尚未完成首仓,无法保本" - avg_e = float(row["avg_entry_price"] or 0) - if avg_e <= 0: - return False, "缺少有效持仓均价" - direction = (row["direction"] or "long").lower() - sym = row["symbol"] - ex_sym = row["exchange_symbol"] or m.normalize_exchange_symbol(sym) - pos = m.get_live_position_contracts(ex_sym, direction) - if pos is None or float(pos) <= 0: - return False, "交易所当前无该方向持仓" - pos_amt = float(pos) - dup = conn.execute( - "SELECT id FROM order_monitors WHERE status='active' AND symbol=? AND direction=? LIMIT 1", - (sym, direction), - ).fetchone() - if dup: - return False, "该币种已有运行中的下单监控,请先结束后再保本移交" - be_fn = getattr(m, "calc_trend_manual_breakeven_stop", None) - if not callable(be_fn): - pct = float(offset_pct if offset_pct is not None else cfg["breakeven_offset_pct"]) - if direction == "short": - new_sl_raw = avg_e * (1.0 - pct / 100.0) - else: - new_sl_raw = avg_e * (1.0 + pct / 100.0) - else: - new_sl_raw = be_fn(direction, avg_e, offset_pct) - if new_sl_raw is None: - return False, "保本价计算失败" - new_sl = m.round_price_to_exchange(ex_sym, new_sl_raw) - if new_sl is None: - return False, "保本价经交易所精度舍入后无效" - new_sl = float(new_sl) - tp = float(row["take_profit"] or 0) - if tp <= 0: - return False, "计划止盈价无效" - cur_sl = float(row["stop_loss"] or 0) - if direction == "long": - if new_sl <= cur_sl: - return False, f"新止损 {new_sl} 未高于当前止损 {cur_sl}(多仓需上移)" - else: - if new_sl >= cur_sl: - return False, f"新止损 {new_sl} 未低于当前止损 {cur_sl}(空仓需下移)" - ok_live, live_reason = m.ensure_exchange_live_ready() - if not ok_live: - return False, live_reason or "实盘未就绪" - plan_id = int(row["id"]) - try: - from strategy_snapshot_lib import save_trend_plan_snapshot - - save_trend_plan_snapshot( - cfg, conn, row, result_label="保本移交", exit_price=None, pnl_amount=None - ) - except Exception: - pass - handoff_row = { - "symbol": sym, - "exchange_symbol": ex_sym, - "direction": direction, - "order_amount": pos_amt, - } - try: - trend_replace_tpsl(cfg, handoff_row, new_sl, tp) - except Exception as e: - fe = getattr(m, "friendly_exchange_error", None) - return False, fe(e) if callable(fe) else str(e) - now_s = m.app_now_str() - _TREND_FLAT_STREAK.pop(plan_id, None) - cur = conn.execute( - "UPDATE trend_pullback_plans SET status=?, message=?, stop_loss=?, " - "breakeven_applied=1, breakeven_applied_at=? WHERE id=? AND status='active'", - ( - TREND_PLAN_STATUS_HANDOFF, - f"保本移交下单监控({TREND_HANDOFF_TRADE_NOTE})", - new_sl, - now_s, - plan_id, - ), - ) - if not getattr(cur, "rowcount", 0): - return False, "计划状态更新失败(可能已被其他操作结束)" - try: - mon_id = _insert_trend_handoff_order_monitor( - cfg, conn, row, new_sl=new_sl, pos_amt=pos_amt - ) - except Exception as e: - conn.execute( - "UPDATE trend_pullback_plans SET status='active', message=? WHERE id=?", - (f"移交下单监控失败:{e}", plan_id), - ) - return False, f"移交下单监控失败:{e}" - pct_used = float( - offset_pct if offset_pct is not None else cfg["breakeven_offset_pct"] - ) - extra = getattr(m, "build_wechat_close_message", None) - send = getattr(m, "send_wechat_msg", None) - pf = getattr(m, "format_price_for_symbol", None) - fmt = (lambda s, p: pf(s, p)) if callable(pf) else (lambda _s, p: str(p)) - try: - from strategy_wechat_notify import notify_trend_plan_ended - - notify_trend_plan_ended( - cfg, - plan_id=plan_id, - symbol=sym, - direction=direction, - end_type="保本移交", - result_label=TREND_HANDOFF_TRADE_NOTE, - extra=f"已移交下单监控 #{mon_id};止损 {fmt(sym, new_sl)} | 止盈 {fmt(sym, tp)}", - ) - except Exception: - pass - if callable(send): - lines = [ - f"# ✅ {sym} 趋势回调保本移交", - f"- 计划 ID:**{plan_id}** → 下单监控 **#{mon_id}**", - f"- 备注:**{TREND_HANDOFF_TRADE_NOTE}**", - f"- 保本止损:{fmt(sym, new_sl)} | 止盈:{fmt(sym, tp)}", - f"- 交易所:已挂止盈止损;平仓后将写入交易记录({ENTRY_REASON_TREND_PULLBACK})", - ] - wl = getattr(m, "_wechat_account_label", None) - if callable(wl): - lines.insert(1, f"**账户:{wl()}**") - send("\n".join(lines)) - return True, None - - -def load_trend_page_context(conn, request_obj, cfg: dict) -> dict[str, Any]: - m = _m(cfg) - _cleanup_stale_previews(conn) - trend_active = int( - conn.execute( - "SELECT COUNT(*) FROM trend_pullback_plans WHERE status='active'" - ).fetchone()[0] - or 0 - ) - trend_plans = [] - trend_dca_probes = [] - raw_plans = conn.execute( - "SELECT * FROM trend_pullback_plans WHERE status='active' ORDER BY id DESC" - ).fetchall() - for r in raw_plans: - try: - enriched = enrich_trend_plan(cfg, r) - trend_plans.append(enriched) - except Exception: - enriched = _row(cfg, r) - trend_plans.append(enriched) - try: - probe = summarize_trend_dca_probe(cfg, r) - trend_dca_probes.append(probe) - if isinstance(enriched, dict): - enriched["dca_probe"] = probe - except Exception: - pass - now = m.app_now() - active_count = m.get_active_position_count(conn) - from daily_open_limit_lib import can_trade_new_open, count_opens_for_trading_day - - trading_day = m.get_trading_day(now) - opens_today = count_opens_for_trading_day(conn, trading_day) - hard_limit = int(getattr(m, "DAILY_OPEN_HARD_LIMIT", 0) or 0) - can_trade_trend = can_trade_new_open( - time_allows=m.trading_day_reset_allows_new_open(now), - active_count=active_count, - max_active_positions=cfg["max_active_positions"], - opens_today=opens_today, - hard_limit=hard_limit, - extra_blocks=trend_active != 0, - ) - trend_preview = None - trend_preview_levels = [] - preview_expires_ms = None - trend_preview_expired = False - pid_arg = (request_obj.args.get("preview_id") or "").strip() - if pid_arg: - pr = conn.execute( - "SELECT * FROM trend_pullback_previews WHERE id=?", (pid_arg,) - ).fetchone() - now_ms = int(time.time() * 1000) - if pr and int(pr["expires_at_ms"] or 0) >= now_ms: - from strategy_trend_lib import build_trend_preview_level_rows - - trend_preview = _row(cfg, pr) - preview_expires_ms = int(pr["expires_at_ms"]) - get_cs = getattr(m, "get_contract_size", None) - if callable(get_cs) and not trend_preview.get("contract_size"): - try: - trend_preview["contract_size"] = float( - get_cs(trend_preview.get("exchange_symbol") or trend_preview.get("symbol") or "") - ) - except (TypeError, ValueError): - pass - trend_preview, trend_preview_levels = build_trend_preview_level_rows(trend_preview) - elif pr: - trend_preview_expired = True - return { - "trend_plans": trend_plans, - "trend_dca_probes": trend_dca_probes, - "trend_active": trend_active, - "can_trade_trend": can_trade_trend, - "trend_preview": trend_preview, - "trend_preview_levels": trend_preview_levels, - "preview_expires_ms": preview_expires_ms, - "trend_preview_expired": trend_preview_expired, - "trend_pullback_dca_legs": cfg["dca_legs"], - "trend_pullback_preview_ttl": cfg["preview_ttl"], - "trend_preview_max_drift_pct": cfg["drift_pct"], - "trend_manual_breakeven_offset_pct": cfg["breakeven_offset_pct"], - } - - -def register_trend_routes(app: Flask, cfg: dict) -> None: - lr = cfg["login_required"] - get_db = cfg["get_db"] - - def _redirect_trend(**kw): - return redirect(url_for("strategy_trading_page", **kw)) - - @app.route("/preview_trend_pullback", methods=["POST"]) - @lr - def preview_trend_pullback(): - conn = get_db() - init_strategy_tables(conn) - okp, msg = precheck_trend_start(cfg, conn) - if not okp: - conn.close() - flash(msg) - return _redirect_trend() - m = _m(cfg) - ok_live, reason = m.ensure_exchange_live_ready() - if not ok_live: - conn.close() - flash(reason) - return _redirect_trend() - payload, err = parse_trend_plan(cfg, request.form) - if err: - conn.close() - flash(err) - return _redirect_trend() - pid = str(uuid.uuid4()) - exp_ms = int(time.time() * 1000) + cfg["preview_ttl"] * 1000 - created = m.app_now_str() - conn.execute( - """INSERT INTO trend_pullback_previews ( - id,symbol,exchange_symbol,direction,leverage,stop_loss,add_upper,take_profit,risk_percent, - snapshot_available_usdt,snapshot_at,live_price_ref,plan_margin_capital,target_order_amount,first_order_amount,remainder_total, - dca_legs,per_leg_amount,grid_prices_json,leg_amounts_json,expires_at_ms,created_at - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - pid, - payload["symbol"], - payload["exchange_symbol"], - payload["direction"], - payload["leverage"], - payload["stop_loss"], - payload["add_upper"], - payload["take_profit"], - payload["risk_percent"], - payload["snapshot_available_usdt"], - payload["snapshot_at"], - payload["live_price_ref"], - payload["plan_margin_capital"], - payload["target_order_amount"], - payload["first_order_amount"], - payload["remainder_total"], - payload["dca_legs"], - payload["per_leg_amount"], - payload["grid_prices_json"], - payload["leg_amounts_json"], - exp_ms, - created, - ), - ) - _insert_preview_snapshot(conn, pid, created, exp_ms, payload) - conn.commit() - conn.close() - flash(f"预览已生成,有效期 {cfg['preview_ttl']} 秒,请核对后点击「确认执行」。") - return _redirect_trend(preview_id=pid) - - @app.route("/execute_trend_pullback", methods=["POST"]) - @lr - def execute_trend_pullback(): - pid = (request.form.get("preview_id") or "").strip() - if not pid: - flash("缺少预览 ID") - return _redirect_trend() - conn = get_db() - init_strategy_tables(conn) - _cleanup_stale_previews(conn) - pr = conn.execute( - "SELECT * FROM trend_pullback_previews WHERE id=?", (pid,) - ).fetchone() - now_ms = int(time.time() * 1000) - if not pr or int(pr["expires_at_ms"] or 0) < now_ms: - conn.close() - flash("预览已过期或不存在,请重新生成预览") - return _redirect_trend() - okp, msg = precheck_trend_start(cfg, conn) - if not okp: - conn.close() - flash(msg) - return _redirect_trend(preview_id=pid) - m = _m(cfg) - ok_live, reason = m.ensure_exchange_live_ready() - if not ok_live: - conn.close() - flash(reason) - return _redirect_trend(preview_id=pid) - snap_prev = float(pr["snapshot_available_usdt"] or 0) - snap_now = m.get_available_trading_usdt() - if snap_now is None or snap_now <= 0: - conn.close() - flash("无法读取当前合约可用余额,请稍后重试") - return _redirect_trend(preview_id=pid) - drift = abs(float(snap_now) - snap_prev) / max(snap_prev, 1e-9) * 100.0 - if drift > cfg["drift_pct"]: - conn.close() - flash( - f"当前可用余额与预览快照偏差 {drift:.2f}%,超过允许 {cfg['drift_pct']}%,请重新生成预览" - ) - return _redirect_trend(preview_id=pid) - symbol = pr["symbol"] - exchange_symbol = pr["exchange_symbol"] - direction = pr["direction"] or "long" - leverage = int(pr["leverage"] or 1) - stop_loss = float(pr["stop_loss"]) - first_amt = float(pr["first_order_amount"] or 0) - live_price = m.get_price(symbol) - if live_price is None: - conn.close() - flash("获取实时价格失败") - return _redirect_trend(preview_id=pid) - try: - o1 = m.place_exchange_order( - exchange_symbol, direction, first_amt, leverage, stop_loss=None, take_profit=None - ) - fill1 = m.resolve_order_entry_price(o1, exchange_symbol, live_price) - trend_refresh_stop_only(cfg, exchange_symbol, direction, stop_loss) - except Exception as e: - conn.close() - fe = getattr(m, "friendly_exchange_error", lambda x, **k: str(x)) - flash(fe(e, available_usdt=snap_now)) - return _redirect_trend(preview_id=pid) - trading_day = m.get_trading_day(m.app_now()) - opened_at = m.app_now_str() - opened_ms = getattr(m, "_to_ms_with_fallback", lambda a, b: None)(None, opened_at) - from strategy_trend_lib import append_leg_fill_price_json - - fills_json = append_leg_fill_price_json(None, fill1) - cur = conn.execute( - """INSERT INTO trend_pullback_plans ( - status,symbol,exchange_symbol,direction,leverage,stop_loss,initial_stop_loss,add_upper,take_profit,risk_percent, - snapshot_available_usdt,snapshot_at,plan_margin_capital,target_order_amount,first_order_amount,remainder_total, - dca_legs,per_leg_amount,grid_prices_json,leg_amounts_json,legs_done,first_order_done,last_mark_price,avg_entry_price,order_amount_open,opened_at,opened_at_ms,session_date,message,leg_fill_prices_json - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - "active", - symbol, - exchange_symbol, - direction, - leverage, - stop_loss, - stop_loss, - float(pr["add_upper"]), - float(pr["take_profit"]), - float(pr["risk_percent"] or 5), - float(snap_now), - opened_at, - float(pr["plan_margin_capital"] or 0), - float(pr["target_order_amount"] or 0), - first_amt, - float(pr["remainder_total"] or 0), - int(pr["dca_legs"] or 0), - float(pr["per_leg_amount"] or 0), - pr["grid_prices_json"] or "[]", - pr["leg_amounts_json"] or "[]", - 0, - 1, - float(live_price), - fill1, - first_amt, - opened_at, - opened_ms, - trading_day, - f"预览ID:{pid[:8]}…", - fills_json, - ), - ) - new_id = int(cur.lastrowid) - conn.execute( - "UPDATE trend_pullback_preview_snapshots SET outcome='executed', executed_plan_id=? WHERE preview_id=?", - (new_id, pid), - ) - conn.execute("DELETE FROM trend_pullback_previews WHERE id=?", (pid,)) - conn.commit() - try: - from strategy_wechat_notify import notify_trend_plan_started - - notify_trend_plan_started( - cfg, - plan_id=new_id, - symbol=symbol, - direction=direction, - leverage=leverage, - stop_loss=stop_loss, - take_profit=float(pr["take_profit"]), - add_upper=float(pr["add_upper"]), - risk_percent=float(pr["risk_percent"] or 5), - dca_legs=int(pr["dca_legs"] or 0), - first_order_amount=first_amt, - avg_entry=fill1, - snapshot_usdt=float(snap_now), - ) - except Exception: - pass - conn.close() - flash("趋势回调已执行:首仓已成交并挂交易所止损,止盈由程序监控。") - return _redirect_trend() - - @app.route("/cancel_trend_pullback_preview", methods=["POST"]) - @lr - def cancel_trend_pullback_preview(): - pid = (request.form.get("preview_id") or "").strip() - conn = get_db() - if pid: - conn.execute( - "UPDATE trend_pullback_preview_snapshots SET outcome='cancelled' WHERE preview_id=? AND outcome='open'", - (pid,), - ) - conn.execute("DELETE FROM trend_pullback_previews WHERE id=?", (pid,)) - conn.commit() - conn.close() - flash("已取消预览") - return _redirect_trend() - - @app.route("/trend_pullback_breakeven/", methods=["POST"]) - @lr - def trend_pullback_breakeven(pid: int): - offset_pct = None - raw = (request.form.get("breakeven_offset_pct") or "").strip() - if raw: - try: - offset_pct = float(raw) - if offset_pct < 0: - raise ValueError - except ValueError: - flash("保本偏移% 格式无效") - return _redirect_trend() - conn = get_db() - row = conn.execute( - "SELECT * FROM trend_pullback_plans WHERE id=? AND status='active'", (pid,) - ).fetchone() - if not row: - conn.close() - flash("未找到运行中的趋势回调计划") - return _redirect_trend() - ok, err = apply_manual_breakeven(cfg, conn, row, offset_pct=offset_pct) - conn.commit() - conn.close() - flash( - "已保本:趋势计划已结束,持仓已移交下单监控并挂止盈止损;平仓后将写入交易记录" - if ok - else (err or "保本移交失败") - ) - return _redirect_trend() - - @app.route("/stop_trend_pullback/") - @lr - def stop_trend_pullback(pid: int): - conn = get_db() - row = conn.execute( - "SELECT * FROM trend_pullback_plans WHERE id=? AND status='active'", (pid,) - ).fetchone() - if not row: - stopped = conn.execute( - "SELECT * FROM trend_pullback_plans WHERE id=? " - "AND status IN ('stopped_sl','stopped_tp','stopped_manual')", - (pid,), - ).fetchone() - if stopped and not _trend_plan_trade_exists(conn, pid): - try: - if _ensure_trend_plan_trade_record(cfg, conn, pid, prefer_label="手动平仓"): - conn.close() - flash("计划已结束,已补录缺失的交易记录") - return _redirect_trend() - except Exception as e: - conn.close() - flash(f"补录交易记录失败:{e}") - return _redirect_trend() - conn.close() - flash("未找到运行中的趋势回调计划") - return _redirect_trend() - m = _m(cfg) - ex_sym = row["exchange_symbol"] or m.normalize_exchange_symbol(row["symbol"]) - direction = row["direction"] or "long" - lev = int(row["leverage"] or 1) - px = m.get_price(row["symbol"]) - exit_p = float(px) if px is not None else 0.0 - ok_live, _ = m.ensure_exchange_live_ready() - if ok_live: - pos = m.get_live_position_contracts(ex_sym, direction) - if pos is not None and pos > 0: - try: - close_resp = trend_market_close(cfg, ex_sym, direction, float(pos), lev) - ep = m.extract_trade_price_from_order(close_resp) - if ep: - exit_p = float(ep) - except Exception as e: - if not m.is_no_position_error(str(e)): - conn.close() - flash(f"平仓失败:{e}") - return _redirect_trend() - try: - cancel_symbol_orders(cfg, ex_sym) - except Exception: - pass - try: - _finalize_plan(cfg, conn, row, "手动平仓", exit_p, user_initiated_risk=True) - except Exception as e: - conn.execute( - "UPDATE trend_pullback_plans SET status='stopped_manual', message=? " - "WHERE id=? AND status='active'", - (f"结束异常:{e}", pid), - ) - conn.commit() - conn.close() - flash(f"计划已结束但记账可能不完整:{e}") - return _redirect_trend() - conn.close() - flash("已结束趋势回调计划") - return _redirect_trend() +"""趋势回调:路由、轮询、页面数据(四所共用,依赖各 app 模块交易所能力)。""" +from __future__ import annotations + +import inspect +import json +import os +import time +import uuid +from typing import Any, Optional + +from flask import Flask, flash, redirect, request, url_for +from jinja2 import ChoiceLoader, FileSystemLoader + +from lib.strategy.strategy_config import resolve_trading_app_module +from lib.strategy.strategy_db import init_strategy_tables +from lib.strategy.strategy_trend_exchange import ( + cancel_symbol_orders, + trend_market_add, + trend_market_close, + trend_refresh_stop_only, + trend_replace_tpsl, +) +from lib.strategy.strategy_trend_lib import ( + build_grid_prices, + build_leg_amounts_json, + calc_risk_fraction, + trend_dca_level_reached, + trend_effective_margin_capital, + validate_trend_bounds, +) +from lib.strategy.strategy_trade_labels import ( + ENTRY_REASON_TREND_PULLBACK, + MONITOR_TYPE_TREND_PULLBACK, + TREND_HANDOFF_KEY_SIGNAL, + TREND_HANDOFF_TRADE_NOTE, +) + +MONITOR_TYPE_TREND = MONITOR_TYPE_TREND_PULLBACK + +# 趋势回调:交易所报空仓需连续 N 次轮询确认,避免 OKX 等 API 瞬时误判立即结束计划 +_TREND_FLAT_STREAK: dict[int, int] = {} +TREND_FLAT_CONFIRM_POLLS = max(1, int(os.getenv("TREND_FLAT_CONFIRM_POLLS", "5"))) +TREND_OPEN_GRACE_SEC = max(0, int(os.getenv("TREND_OPEN_GRACE_SEC", "180"))) +_TREND_LIVE_SKIP_LOG_TS = 0.0 +_TREND_POLL_STATE: dict[str, Any] = { + "updated_at": None, + "live_ok": True, + "live_reason": "", + "plans": {}, +} + + +def get_trend_poll_state() -> dict: + return dict(_TREND_POLL_STATE or {}) + + +def _log_trend_live_skip(reason: str) -> None: + global _TREND_LIVE_SKIP_LOG_TS + now = time.time() + if now - _TREND_LIVE_SKIP_LOG_TS < 60: + return + _TREND_LIVE_SKIP_LOG_TS = now + print(f"[trend_pullback] poll skipped (live not ready): {reason}", flush=True) + + +def _set_trend_poll_plan(plan_id: int, info: dict) -> None: + plans = dict(_TREND_POLL_STATE.get("plans") or {}) + plans[str(plan_id)] = info + _TREND_POLL_STATE["plans"] = plans + + +def summarize_trend_dca_probe(cfg: dict, row) -> dict: + """诊断单计划为何未补仓(供页面 / API)。""" + m = _m(cfg) + d = _row(cfg, row) + plan_id = int(d.get("id") or 0) + sym = d.get("symbol") or "" + direction = (d.get("direction") or "long").lower() + ex_sym = d.get("exchange_symbol") or m.normalize_exchange_symbol(sym) + out: dict[str, Any] = { + "plan_id": plan_id, + "symbol": sym, + "mark_price": None, + "next_trigger": None, + "trigger_reached": False, + "legs_done": int(d.get("legs_done") or 0), + "first_order_done": int(d.get("first_order_done") or 0), + "block_reason": None, + } + try: + legs_done = int(d.get("legs_done") or 0) + grid = json.loads(d.get("grid_prices_json") or "[]") + if not isinstance(grid, list): + grid = [] + leg_amounts = json.loads(d.get("leg_amounts_json") or "[]") + if not isinstance(leg_amounts, list): + leg_amounts = [] + except Exception: + grid = [] + leg_amounts = [] + legs_done = 0 + pf = _trend_poll_price(m, sym, ex_sym, direction) + out["mark_price"] = pf + ok_live, live_reason = m.ensure_exchange_live_ready() + out["live_ok"] = ok_live + if not ok_live: + out["block_reason"] = live_reason or "实盘未就绪" + if not int(d.get("first_order_done") or 0): + out["block_reason"] = out["block_reason"] or "首仓未完成" + return out + if legs_done >= len(grid) or legs_done >= len(leg_amounts): + out["block_reason"] = out["block_reason"] or "补仓档已全部完成或无 grid" + return out + try: + level = float(grid[legs_done]) + except (TypeError, ValueError, IndexError): + out["block_reason"] = out["block_reason"] or "无效补仓触发价" + return out + out["next_trigger"] = level + if pf is None: + out["block_reason"] = out["block_reason"] or "无法读取标记价" + return out + reached = trend_dca_level_reached(direction, float(pf), level) + out["trigger_reached"] = reached + if reached and not ok_live: + out["block_reason"] = live_reason or "LIVE_TRADING_ENABLED=false" + elif reached and ok_live: + pos = m.get_live_position_contracts(ex_sym, direction) + try: + local_open = float(d.get("order_amount_open") or 0) + except (TypeError, ValueError): + local_open = 0.0 + if pos is None and local_open > 0: + pos = local_open + if pos is None: + out["block_reason"] = "无法读取交易所持仓" + elif float(pos) <= 0: + out["block_reason"] = "交易所无持仓" + else: + out["block_reason"] = ( + "标记价已触达,轮询应自动下单;若仍未补请确认 PM2 进程 crypto_gate_bot " + "(非 manual-agent-gate-bot)在运行,并查看 pm2 logs crypto_gate_bot" + ) + elif not reached: + out["block_reason"] = f"标记价 {pf} 未触达下一档 {level}" + return out + + +def trend_add_zone_label(direction: str) -> str: + return "补仓下沿" if (direction or "long").strip().lower() == "short" else "补仓上沿" + + +def install_strategy_trend(app: Flask, repo_root: str, app_module: Any = None, **build_kw) -> dict: + from lib.strategy.strategy_register import attach_strategy_templates + + attach_strategy_templates(app, repo_root) + cfg = build_trend_config(app_module, **build_kw) + app.extensions["strategy_trend_cfg"] = cfg + register_trend_routes(app, cfg) + _patch_hub_monitor_enrich(app, cfg) + roll_cfg = app.extensions.get("strategy_roll_cfg") + if isinstance(roll_cfg, dict): + from lib.strategy.strategy_roll_ui_lib import patch_roll_hub_enrich + + patch_roll_hub_enrich(app, roll_cfg) + _patch_hub_trend_views(app) + + @app.context_processor + def _trend_ctx(): + return {"trend_add_zone_label": trend_add_zone_label} + + return cfg + + +def build_trend_config(app_module: Any = None, **kw) -> dict[str, Any]: + m = resolve_trading_app_module(app_module) + dca = max(1, int(os.getenv("TREND_PULLBACK_DCA_LEGS", kw.get("dca_legs", "5")))) + preview_ttl = max(10, int(os.getenv("TREND_PULLBACK_PREVIEW_TTL_SECONDS", "120"))) + drift = float(os.getenv("TREND_PREVIEW_MAX_BALANCE_DRIFT_PCT", "5")) + be_pct = float(os.getenv("TREND_PULLBACK_MANUAL_BREAKEVEN_OFFSET_PCT", "0.3")) + buf = float(getattr(m, "FULL_MARGIN_BUFFER_RATIO", 0.95)) + + def amount_precise(ex_sym, amt): + fn = getattr(m, "_safe_amount_to_precision", None) + if callable(fn): + return fn(ex_sym, amt) + try: + m.ensure_markets_loaded() + return float(m.exchange.amount_to_precision(ex_sym, float(amt))) + except Exception: + return None + + def send_wechat(content): + fn = getattr(m, "send_wechat_msg", None) + if callable(fn): + fn(content) + + def wechat_account_label(): + fn = getattr(m, "_wechat_account_label", None) + if callable(fn): + try: + return fn() + except Exception: + pass + return getattr(m, "EXCHANGE_DISPLAY_NAME", "") or "" + + def wechat_direction_text(direction): + fn = getattr(m, "_wechat_direction_text", None) + if callable(fn): + try: + return fn(direction) + except Exception: + pass + d = (direction or "long").strip().lower() + return "做多" if d == "long" else "做空" + + return { + "app_module": m, + "exchange_display": getattr(m, "EXCHANGE_DISPLAY_NAME", ""), + "login_required": m.login_required, + "get_db": m.get_db, + "row_to_dict": m.row_to_dict, + "dca_legs": dca, + "preview_ttl": preview_ttl, + "drift_pct": drift, + "breakeven_offset_pct": be_pct, + "margin_buffer": buf, + "amount_precise": amount_precise, + "max_active_positions": int(getattr(m, "MAX_ACTIVE_POSITIONS", 1)), + "reset_hour": int(getattr(m, "TRADING_DAY_RESET_HOUR", 8)), + "monitor_type_trend": MONITOR_TYPE_TREND, + "send_wechat": send_wechat, + "format_price": getattr(m, "format_price_for_symbol", None), + "wechat_account_label": wechat_account_label, + "wechat_direction_text": wechat_direction_text, + } + + +def _m(cfg: dict): + return cfg["app_module"] + + +def _row(cfg, row) -> dict: + return cfg["row_to_dict"](row) + + +def precheck_trend_start(cfg: dict, conn) -> tuple[bool, str]: + m = _m(cfg) + mode = getattr(m, "POSITION_SIZING_MODE", None) or "risk" + try: + from lib.trade.position_sizing_lib import OPEN_SOURCE_TREND, assert_open_source_allowed + + ok_src, src_msg = assert_open_source_allowed(mode, OPEN_SOURCE_TREND) + if not ok_src: + return False, src_msg + except Exception: + pass + now = m.app_now() + if not m.trading_day_reset_allows_new_open(now): + return False, f"北京时间 {cfg['reset_hour']}:00 前不允许持仓" + active = m.get_active_position_count(conn) + if active >= cfg["max_active_positions"]: + return ( + False, + f"已达最大持仓数({active}/{cfg['max_active_positions']})," + "请先结束「实盘下单」中的持仓,再启动趋势回调", + ) + trend_n = conn.execute( + "SELECT COUNT(*) FROM trend_pullback_plans WHERE status='active'" + ).fetchone()[0] + if int(trend_n or 0) > 0: + return False, "已存在运行中的趋势回调计划" + return True, "" + + +def _cleanup_stale_previews(conn) -> None: + ms = int(time.time() * 1000) + stale = conn.execute( + "SELECT id FROM trend_pullback_previews WHERE expires_at_ms < ?", (ms,) + ).fetchall() + for row in stale: + try: + conn.execute( + "UPDATE trend_pullback_preview_snapshots SET outcome='expired' " + "WHERE preview_id=? AND outcome='open'", + (row["id"],), + ) + except Exception: + pass + conn.execute("DELETE FROM trend_pullback_previews WHERE expires_at_ms < ?", (ms,)) + + +def parse_trend_plan(cfg: dict, form_dict) -> tuple[Optional[dict], Optional[str]]: + m = _m(cfg) + d = form_dict or {} + symbol = m.normalize_symbol_input(d.get("symbol")) + if not symbol: + return None, "symbol 不能为空" + direction = (d.get("direction") or "long").strip().lower() + if direction not in ("long", "short"): + return None, "方向错误" + try: + stop_loss = float(d.get("sl")) + add_upper = float(d.get("add_upper")) + take_profit = float(d.get("take_profit")) + risk_percent = float(d.get("risk_percent") or "5") + except Exception: + return None, "价格或风险比例格式错误" + try: + lev_raw = m.parse_positive_float(d.get("leverage")) + leverage = int(lev_raw) if lev_raw is not None else m.infer_leverage(symbol) + except Exception: + return None, "杠杆格式错误" + if leverage <= 0 or risk_percent <= 0: + return None, "杠杆与风险比例必须大于0" + bound_err = validate_trend_bounds(direction, stop_loss, add_upper) + if bound_err: + return None, bound_err + snap = m.get_available_trading_usdt() + if snap is None or snap <= 0: + return None, "无法读取合约账户 USDT 可用余额,请检查 API 与账户类型" + live_price = m.get_price(symbol) + if live_price is None: + return None, "获取实时价格失败" + exchange_symbol = m.normalize_exchange_symbol(symbol) + rf = calc_risk_fraction(direction, add_upper, stop_loss) + if rf is None or rf <= 0: + return None, "止损与补仓区间边界组合无法计算风险比例" + risk_budget = float(snap) * (risk_percent / 100.0) + notional = risk_budget / rf + margin_plan = notional / float(leverage) + margin_plan = min(margin_plan, float(snap) * cfg["margin_buffer"]) + if margin_plan <= 0: + return None, "计划保证金过小" + try: + target_amt, _ = m.prepare_order_amount(exchange_symbol, margin_plan, leverage, live_price) + except Exception as e: + return None, str(e) + ap = cfg["amount_precise"] + first_amt = ap(exchange_symbol, float(target_amt) * 0.5) + if first_amt is None or first_amt <= 0: + return None, "首仓张数过小(低于交易所最小张数),请提高风险比例或杠杆" + remainder_total = ap(exchange_symbol, max(0.0, float(target_amt) - float(first_amt))) + if remainder_total is None: + remainder_total = 0.0 + m.ensure_markets_loaded() + market = m.exchange.market(exchange_symbol) + min_amt = float((market.get("limits", {}).get("amount", {}) or {}).get("min") or 0) + n_legs, leg_json, per_ref = build_leg_amounts_json( + exchange_symbol, remainder_total, cfg["dca_legs"], ap, min_amt + ) + if n_legs <= 0: + return None, "剩余计划张数不足以拆出补仓档,请提高风险比例或放宽止损与补仓区间间距" + grid = build_grid_prices(direction, stop_loss, add_upper, n_legs) + if len(grid) != n_legs: + return None, "补仓网格生成失败" + opened_at = m.app_now_str() + try: + leg_list = json.loads(leg_json) + except Exception: + leg_list = [] + contract_size = float(market.get("contractSize") or 1) + return { + "symbol": symbol, + "exchange_symbol": exchange_symbol, + "direction": direction, + "leverage": leverage, + "stop_loss": stop_loss, + "add_upper": add_upper, + "take_profit": take_profit, + "risk_percent": risk_percent, + "snapshot_available_usdt": float(snap), + "snapshot_at": opened_at, + "live_price_ref": float(live_price), + "plan_margin_capital": float(margin_plan), + "target_order_amount": float(target_amt), + "first_order_amount": float(first_amt), + "remainder_total": float(remainder_total), + "dca_legs": int(n_legs), + "per_leg_amount": float(per_ref), + "grid_prices_json": json.dumps(grid), + "leg_amounts_json": leg_json, + "grid": grid, + "leg_amounts": leg_list, + "contract_size": contract_size, + }, None + + +def _insert_preview_snapshot(conn, preview_id: str, created: str, exp_ms: int, pl: dict) -> None: + conn.execute( + """INSERT INTO trend_pullback_preview_snapshots ( + preview_id,symbol,exchange_symbol,direction,leverage,stop_loss,add_upper,take_profit,risk_percent, + snapshot_available_usdt,snapshot_at,live_price_ref,plan_margin_capital,target_order_amount,first_order_amount,remainder_total, + dca_legs,per_leg_amount,grid_prices_json,leg_amounts_json,expires_at_ms,preview_created_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + preview_id, + pl["symbol"], + pl["exchange_symbol"], + pl["direction"], + pl["leverage"], + pl["stop_loss"], + pl["add_upper"], + pl["take_profit"], + pl["risk_percent"], + pl["snapshot_available_usdt"], + pl["snapshot_at"], + pl["live_price_ref"], + pl["plan_margin_capital"], + pl["target_order_amount"], + pl["first_order_amount"], + pl["remainder_total"], + pl["dca_legs"], + pl["per_leg_amount"], + pl["grid_prices_json"], + pl["leg_amounts_json"], + exp_ms, + created, + ), + ) + + +def _format_trend_price(cfg: dict, symbol: str, value) -> str: + if value in (None, ""): + return "—" + m = _m(cfg) + sym = symbol or "" + norm = getattr(m, "normalize_exchange_symbol", None) + if callable(norm): + try: + sym = norm(sym) or sym + except Exception: + pass + try: + m.ensure_markets_loaded() + return str(m.exchange.price_to_precision(sym, float(value))) + except Exception: + fn = getattr(m, "format_price_for_symbol", None) + if callable(fn): + return fn(symbol, value) + return str(value) + + +def _trend_add_leg_fields(cfg: dict, d: dict) -> dict: + """解析已补仓次数与已触达网格价(供策略页与中控 monitor 共用)。""" + import json + + out = dict(d) + try: + legs_done = int(out.get("legs_done") or 0) + except (TypeError, ValueError): + legs_done = 0 + try: + dca_legs = int(out.get("dca_legs") or 0) + except (TypeError, ValueError): + dca_legs = 0 + try: + grid = json.loads(out.get("grid_prices_json") or "[]") + if not isinstance(grid, list): + grid = [] + except Exception: + grid = [] + add_prices: list[float] = [] + try: + from lib.strategy.strategy_trend_lib import trend_leg_display_price + + for i in range(1, legs_done + 1): + px = trend_leg_display_price(out, i) + if px is not None: + add_prices.append(float(px)) + except Exception: + pass + if not add_prices: + for x in grid[:legs_done]: + try: + add_prices.append(float(x)) + except (TypeError, ValueError): + pass + sym = out.get("exchange_symbol") or out.get("symbol") or "" + out["add_count"] = legs_done + out["add_count_total"] = dca_legs + out["add_prices"] = add_prices + out["add_prices_display"] = [_format_trend_price(cfg, sym, p) for p in add_prices] + for field in ("stop_loss", "take_profit", "add_upper", "avg_entry_price"): + if out.get(field) not in (None, ""): + out[f"{field}_display"] = _format_trend_price(cfg, sym, out.get(field)) + return out + + +def enrich_trend_plan_for_hub(cfg: dict, raw: dict) -> dict: + """中控 /api/hub/monitor:与策略页运行中计划卡片同字段(浮盈亏、标记价、盈亏比等)。""" + d = enrich_trend_plan(cfg, dict(raw or {})) + d["monitor_source"] = "趋势回调计划" + m = _m(cfg) + try: + snap = float(d.get("snapshot_available_usdt") or 0) + margin = float(d.get("plan_margin_capital") or 0) + if snap > 0 and margin > 0: + d["position_ratio_pct"] = round(margin / snap * 100.0, 2) + except (TypeError, ValueError): + pass + return d + + +def _patch_hub_trend_views(app: Flask) -> None: + """将趋势回调路由注册进 HUB_CTX.views,供中控 /api/hub/trend/* 调用。""" + ctx = dict(app.config.get("HUB_CTX") or {}) + views = dict(ctx.get("views") or {}) + for name in ( + "preview_trend_pullback", + "execute_trend_pullback", + "stop_trend_pullback", + "trend_pullback_breakeven", + ): + vf = app.view_functions.get(name) + if vf is not None: + views[name] = vf + ctx["views"] = views + app.config["HUB_CTX"] = ctx + + +def patch_trend_hub_enrich(app: Flask, cfg: dict) -> None: + """hub_bridge install 之后调用:四所 /api/hub/monitor 趋势字段与策略页一致。""" + _patch_hub_monitor_enrich(app, cfg) + + +def _patch_hub_monitor_enrich(app: Flask, cfg: dict) -> None: + ctx = dict(app.config.get("HUB_CTX") or {}) + prev = ctx.get("enrich_monitor") + + def enrich_monitor(keys=None, orders=None, trends=None, rolls=None): + payload: dict[str, Any] = {} + if callable(prev): + try: + prev_out = prev(keys=keys, orders=orders, trends=trends, rolls=rolls) + if isinstance(prev_out, dict): + payload.update(prev_out) + except Exception: + pass + if trends: + payload["trends"] = [ + enrich_trend_plan_for_hub(cfg, t) for t in trends if isinstance(t, dict) + ] + return payload + + ctx["enrich_monitor"] = enrich_monitor + app.config["HUB_CTX"] = ctx + + +def enrich_trend_plan(cfg: dict, row) -> dict: + m = _m(cfg) + d = _row(cfg, row) + try: + d["breakeven_applied"] = int(d.get("breakeven_applied") or 0) != 0 + except Exception: + d["breakeven_applied"] = False + ex_sym = d.get("exchange_symbol") or m.normalize_exchange_symbol(d.get("symbol") or "") + direction = (d.get("direction") or "long").lower() + metrics_fn = getattr(m, "get_live_position_exchange_metrics", None) + met = None + if callable(metrics_fn): + try: + lev = int(d.get("leverage") or 0) or None + except (TypeError, ValueError): + lev = None + try: + met = metrics_fn(ex_sym, direction, order_leverage=lev) + except TypeError: + met = metrics_fn(ex_sym, direction) + if met and met.get("entry_price") is not None: + try: + live_entry = float(met["entry_price"]) + if live_entry > 0: + d["avg_entry_price"] = live_entry + except (TypeError, ValueError): + pass + if met and met.get("unrealized_pnl") is not None: + d["floating_pnl"] = float(met["unrealized_pnl"]) + elif ( + met + and met.get("mark_price") is not None + and d.get("avg_entry_price") is not None + ): + try: + from lib.hub.hub_position_metrics import estimate_linear_swap_upnl_usdt + + entry = float(d["avg_entry_price"]) + mark = float(met["mark_price"]) + qty = None + cs = 1.0 + get_qty = getattr(m, "get_live_position_contracts", None) + get_cs = getattr(m, "get_contract_size", None) + if callable(get_qty): + qty = get_qty(ex_sym, direction) + if callable(get_cs): + cs = float(get_cs(ex_sym)) + upnl = estimate_linear_swap_upnl_usdt( + direction, entry, mark, qty, cs + ) + d["floating_pnl"] = float(upnl) if upnl is not None else None + except (TypeError, ValueError): + d["floating_pnl"] = None + else: + d["floating_pnl"] = None + if met and met.get("mark_price") is not None: + d["floating_mark"] = float(met["mark_price"]) + else: + d["floating_mark"] = None + else: + d["floating_pnl"] = d["floating_mark"] = None + get_cs = getattr(m, "get_contract_size", None) + if callable(get_cs): + try: + d["contract_size"] = float(get_cs(ex_sym)) + except (TypeError, ValueError): + pass + d = _trend_add_leg_fields(cfg, d) + from lib.strategy.strategy_snapshot_lib import attach_trend_dca_levels + from lib.strategy.strategy_trend_lib import calc_trend_plan_money_metrics + + d = attach_trend_dca_levels(d) + money = calc_trend_plan_money_metrics(d) + if money.get("money_rr") is not None: + d["money_rr"] = money["money_rr"] + d["planned_rr"] = money["money_rr"] + if money.get("risk_amount_u") is not None: + d["risk_amount_u"] = money["risk_amount_u"] + try: + d["breakeven_default_offset_pct"] = float(cfg.get("breakeven_offset_pct", 0.3)) + except (TypeError, ValueError): + d["breakeven_default_offset_pct"] = 0.3 + return d + + +def _weighted_avg(old_avg, old_amt, fill_px, add_amt): + try: + oa, aa = float(old_amt), float(add_amt) + if oa <= 0: + return float(fill_px) + return (float(old_avg) * oa + float(fill_px) * aa) / (oa + aa) + except Exception: + return float(fill_px or 0) + + +def _plan_stop_status(result_label: str) -> str: + if result_label == "止盈": + return "stopped_tp" + if result_label == "止损": + return "stopped_sl" + return "stopped_manual" + + +def _call_insert_trade_record(m, plan_id: int, kwargs: dict) -> None: + """按各所 insert_trade_record 签名过滤参数,避免未知字段导致记账失败。""" + fn = getattr(m, "insert_trade_record", None) + if not callable(fn): + raise RuntimeError("app_module 缺少 insert_trade_record") + allowed = set(inspect.signature(fn).parameters.keys()) + call = {k: v for k, v in kwargs.items() if k in allowed} + if "trend_plan_id" in allowed: + call["trend_plan_id"] = int(plan_id) + fn(**call) + + +def _best_trend_close_snapshot(conn, plan_id: int) -> dict | None: + from lib.strategy.strategy_snapshot_lib import ( + FINAL_TREND_CLOSE_LABELS, + STRATEGY_TREND, + _final_trend_close_rank, + ) + + rows = conn.execute( + f"""SELECT * FROM strategy_trade_snapshots + WHERE strategy_type=? AND source_id=? + AND result_label IN ({",".join("?" * len(FINAL_TREND_CLOSE_LABELS))})""", + (STRATEGY_TREND, int(plan_id), *FINAL_TREND_CLOSE_LABELS), + ).fetchall() + if not rows: + return None + parsed = [_row_dict(row) for row in rows] + return max( + parsed, + key=lambda d: ( + _final_trend_close_rank(str(d.get("result_label") or "")), + int(d.get("id") or 0), + ), + ) + + +def _ensure_trend_plan_trade_record( + cfg: dict, conn, plan_id: int, *, prefer_label: str = "手动平仓" +) -> bool: + """计划已结束但 trade_records 缺失时,从策略快照补录一条。""" + if _trend_plan_trade_exists(conn, plan_id): + return True + m = _m(cfg) + plan = conn.execute( + "SELECT * FROM trend_pullback_plans WHERE id=?", (int(plan_id),) + ).fetchone() + if not plan: + return False + plan_d = _row_dict(plan) + snap = _best_trend_close_snapshot(conn, plan_id) + if not snap: + return False + try: + payload = json.loads(snap.get("snapshot_json") or "{}") + except Exception: + payload = {} + sym = snap.get("symbol") or plan_d.get("symbol") or payload.get("symbol") + direction = snap.get("direction") or plan_d.get("direction") or "long" + result = (prefer_label or "").strip() or (snap.get("result_label") or "").strip() or "手动平仓" + opened_at = snap.get("opened_at") or plan_d.get("opened_at") + closed_at = snap.get("closed_at") + pnl_amount = snap.get("pnl_amount") + if pnl_amount is None: + pnl_amount = payload.get("pnl_amount") + avg_e = float(payload.get("avg_entry_price") or plan_d.get("avg_entry_price") or 0) + margin_cap = trend_effective_margin_capital(plan_d) + lev = int(plan_d.get("leverage") or 1) + hold_seconds = m.calc_hold_seconds( + opened_at or "", + m.parse_dt_for_trading_day(closed_at) or m.app_now(), + ) + res = m.normalize_result_with_pnl(result, float(pnl_amount or 0)) + risk_amt = m.calc_risk_amount_from_plan( + direction, + float(plan_d.get("add_upper") or 0), + float(plan_d.get("stop_loss") or 0), + float(plan_d.get("plan_margin_capital") or 0), + lev, + ) + planned_rr = m.calc_rr_ratio( + direction, + avg_e, + float(plan_d.get("stop_loss") or 0), + float(plan_d.get("take_profit") or 0), + ) + session_date = plan_d.get("session_date") or m.get_trading_day() + _bump_session_capital_no_commit(m, conn, session_date, float(pnl_amount or 0)) + _call_insert_trade_record( + m, + plan_id, + dict( + conn=conn, + symbol=sym, + monitor_type=MONITOR_TYPE_TREND, + direction=direction, + trigger_price=avg_e, + stop_loss=float(plan_d.get("stop_loss") or 0), + initial_stop_loss=float(plan_d.get("initial_stop_loss") or plan_d.get("stop_loss") or 0), + take_profit=float(plan_d.get("take_profit") or 0), + margin_capital=margin_cap, + leverage=lev, + pnl_amount=pnl_amount, + hold_seconds=hold_seconds, + trade_style="trend_pullback", + risk_amount=risk_amt, + planned_rr=planned_rr, + actual_rr=m.calc_actual_rr(pnl_amount, risk_amt), + result=res, + opened_at=opened_at, + closed_at=closed_at, + entry_reason=ENTRY_REASON_TREND_PULLBACK, + ), + ) + conn.commit() + return True + + +def sync_trend_plans_after_external_close( + cfg: dict, conn, symbol: str, direction: str +) -> dict[str, Any]: + """中控/外部全平后:结束仍 active 的同币种同向趋势计划(避免监控再记一条止损)。""" + m = _m(cfg) + sym = m.normalize_symbol_input(symbol) if hasattr(m, "normalize_symbol_input") else (symbol or "").strip() + if not sym: + return {"ok": False, "msg": "symbol 无效", "finalized": 0} + direction = (direction or "long").strip().lower() + rows = conn.execute( + "SELECT * FROM trend_pullback_plans WHERE status='active' AND symbol=? AND direction=?", + (sym, direction), + ).fetchall() + finalized = 0 + for row in rows: + px = m.get_price(row["symbol"]) + exit_p = float(px) if px is not None else 0.0 + before = _trend_plan_trade_exists(conn, int(row["id"])) + _finalize_plan(cfg, conn, row, "手动平仓", exit_p) + if not before: + finalized += 1 + return {"ok": True, "finalized": finalized, "symbol": sym, "direction": direction} + + +def _trend_plan_trade_exists(conn, plan_id: int) -> bool: + try: + return conn.execute( + "SELECT id FROM trade_records WHERE trend_plan_id=? LIMIT 1", + (int(plan_id),), + ).fetchone() is not None + except Exception: + return False + + +def _bump_session_capital_no_commit( + m, conn, session_date: str, pnl_amount: float +) -> float | None: + """更新当日资金,不单独 commit(与 _finalize_plan 同一事务)。""" + try: + row = conn.execute( + "SELECT current_capital FROM trading_sessions WHERE session_date = ?", + (session_date,), + ).fetchone() + if not row: + start_cap = float(getattr(m, "DAILY_START_CAPITAL", 0) or 0) + if start_cap <= 0: + ensure = getattr(m, "ensure_session", None) + if callable(ensure): + ensured = ensure(conn, session_date) + row = ensured + else: + return None + else: + conn.execute( + "INSERT OR IGNORE INTO trading_sessions " + "(session_date, start_capital, current_capital) VALUES (?,?,?)", + (session_date, start_cap, start_cap), + ) + row = conn.execute( + "SELECT current_capital FROM trading_sessions WHERE session_date = ?", + (session_date,), + ).fetchone() + if not row: + return None + new_capital = float(row["current_capital"]) + float(pnl_amount) + conn.execute( + "UPDATE trading_sessions SET current_capital = ?, updated_at = CURRENT_TIMESTAMP " + "WHERE session_date = ?", + (round(new_capital, 4), session_date), + ) + return round(new_capital, 4) + except Exception: + return None + + +def _apply_trend_user_risk_close(cfg: dict, conn, *, trade_record_id=None, closed_at_ms=None) -> None: + m = _m(cfg) + fn = getattr(m, "hub_user_initiated_close", None) + from lib.trade.account_risk_lib import CLOSE_SOURCE_USER_TREND_STOP + + if callable(fn): + fn( + conn, + source=CLOSE_SOURCE_USER_TREND_STOP, + count=1, + trade_record_id=trade_record_id, + closed_at_ms=closed_at_ms, + ) + return + from lib.trade.account_risk_lib import on_user_initiated_close + + on_user_initiated_close( + conn, + source=CLOSE_SOURCE_USER_TREND_STOP, + trade_record_id=trade_record_id, + closed_at_ms=closed_at_ms, + trading_day=m.get_trading_day(), + now=m.app_now(), + count=1, + ) + + +def _finalize_plan(cfg: dict, conn, row, result_label: str, exit_price: float, *, user_initiated_risk: bool = False) -> None: + m = _m(cfg) + plan_id = int(row["id"]) + active = conn.execute( + "SELECT * FROM trend_pullback_plans WHERE id=? AND status='active'", + (plan_id,), + ).fetchone() + if not active: + return + row = active + sym = row["symbol"] + direction = row["direction"] or "long" + ex_sym = row["exchange_symbol"] or m.normalize_exchange_symbol(sym) + closed_at = m.app_now_str() + opened_at = row["opened_at"] or closed_at + hold_seconds = m.calc_hold_seconds(opened_at, m.parse_dt_for_trading_day(closed_at) or m.app_now()) + plan_margin = float(row["plan_margin_capital"] or 0) + margin_cap = trend_effective_margin_capital(_row(cfg, row)) + lev = int(row["leverage"] or 1) + avg_e = float(row["avg_entry_price"] or 0) + pnl_amount = m.calc_pnl(direction, avg_e, float(exit_price), margin_cap, lev) + res = m.normalize_result_with_pnl(result_label, pnl_amount) + risk_amt = m.calc_risk_amount_from_plan( + direction, float(row["add_upper"]), float(row["stop_loss"]), plan_margin, lev + ) + try: + target = float(row["target_order_amount"] or 0) + open_amt = float(row["order_amount_open"] or 0) + if risk_amt is not None and target > 0 and open_amt > 0: + risk_amt = round(float(risk_amt) * min(1.0, open_amt / target), 6) + except (TypeError, ValueError): + pass + planned_rr = m.calc_rr_ratio(direction, avg_e, float(row["stop_loss"]), float(row["take_profit"])) + st = _plan_stop_status(result_label) + cur = conn.execute( + "UPDATE trend_pullback_plans SET status=?, message=? WHERE id=? AND status='active'", + (st, res, plan_id), + ) + if not getattr(cur, "rowcount", 0): + return + try: + from lib.strategy.strategy_snapshot_lib import save_trend_plan_snapshot + + save_trend_plan_snapshot( + cfg, + conn, + row, + result_label=result_label, + exit_price=float(exit_price) if exit_price is not None else None, + pnl_amount=float(pnl_amount) if pnl_amount is not None else None, + closed_at=closed_at, + ) + except Exception: + pass + try: + cancel_symbol_orders(cfg, ex_sym) + except Exception: + pass + session_capital = None + trade_record_id = None + if not _trend_plan_trade_exists(conn, plan_id): + session_date = row["session_date"] or m.get_trading_day() + session_capital = _bump_session_capital_no_commit( + m, conn, session_date, pnl_amount + ) + _call_insert_trade_record( + m, + plan_id, + dict( + conn=conn, + symbol=sym, + monitor_type=MONITOR_TYPE_TREND, + direction=direction, + trigger_price=avg_e, + stop_loss=float(row["stop_loss"]), + initial_stop_loss=float(row.get("initial_stop_loss") or row["stop_loss"]), + take_profit=float(row["take_profit"]), + margin_capital=margin_cap, + leverage=lev, + pnl_amount=pnl_amount, + hold_seconds=hold_seconds, + trade_style="trend_pullback", + risk_amount=risk_amt, + planned_rr=planned_rr, + actual_rr=m.calc_actual_rr(pnl_amount, risk_amt), + result=res, + opened_at=opened_at, + closed_at=closed_at, + entry_reason=ENTRY_REASON_TREND_PULLBACK, + ), + ) + try: + from lib.trade.account_risk_lib import insert_trade_record_id + + trade_record_id = insert_trade_record_id(conn) + except Exception: + trade_record_id = None + if user_initiated_risk: + closed_ms = None + to_ms = getattr(m, "_to_ms_with_fallback", None) + if callable(to_ms): + try: + closed_ms = to_ms(None, closed_at) + except Exception: + closed_ms = None + _apply_trend_user_risk_close( + cfg, + conn, + trade_record_id=trade_record_id, + closed_at_ms=closed_ms, + ) + conn.commit() + try: + from lib.strategy.strategy_wechat_notify import notify_trend_plan_ended + + notify_trend_plan_ended( + cfg, + plan_id=plan_id, + symbol=sym, + direction=direction, + end_type=result_label, + result_label=res, + exit_price=float(exit_price) if exit_price is not None else None, + pnl_amount=float(pnl_amount) if pnl_amount is not None else None, + ) + except Exception: + pass + extra = getattr(m, "build_wechat_close_message", None) + send = getattr(m, "send_wechat_msg", None) + if callable(extra) and callable(send): + send( + extra( + symbol=sym, + direction=direction, + result=f"{res}({MONITOR_TYPE_TREND})", + pnl_amount=pnl_amount, + hold_seconds=hold_seconds, + trigger_price=avg_e, + current_price=float(exit_price), + stop_loss=float(row["stop_loss"]), + take_profit=float(row["take_profit"]), + close_order_id="-", + extra_note="计划本金口径:启动时合约可用余额快照;止盈由程序监控", + session_capital_fallback=session_capital, + ) + ) + + +def _trend_plan_open_age_sec(row, m) -> float: + opened_ms = None + try: + if "opened_at_ms" in row.keys() and row["opened_at_ms"]: + opened_ms = int(row["opened_at_ms"]) + except Exception: + opened_ms = None + to_ms = getattr(m, "_to_ms_with_fallback", None) + if callable(to_ms): + opened_ms = to_ms(opened_ms, row["opened_at"] if "opened_at" in row.keys() else None) + if opened_ms is None and "opened_at" in row.keys(): + opened_ms = to_ms(None, row["opened_at"]) + if not opened_ms: + return 0.0 + return max(0.0, (time.time() * 1000 - opened_ms) / 1000.0) + + +def _trend_hit_take_profit(direction: str, mark_price: float, take_profit: float, avg_entry: float) -> bool: + try: + pf = float(mark_price) + tp = float(take_profit) + entry = float(avg_entry) + except (TypeError, ValueError): + return False + if entry <= 0 or tp <= 0: + return False + direction = (direction or "long").lower() + if direction == "long": + return tp > entry and pf >= tp + return tp < entry and pf <= tp + + +def _trend_poll_price(m, sym: str, ex_sym: str, direction: str) -> Optional[float]: + """补仓/止盈判定用标记价(与页面「标记价」一致),无标记价时回退 last。""" + fn = getattr(m, "get_symbol_mark_price", None) + if callable(fn): + try: + px = fn(sym) + if px is not None and float(px) > 0: + return float(px) + except Exception: + pass + metrics_fn = getattr(m, "get_live_position_exchange_metrics", None) + if callable(metrics_fn): + try: + met = metrics_fn(ex_sym, direction) + if met and met.get("mark_price") is not None: + px = float(met["mark_price"]) + if px > 0: + return px + except Exception: + pass + px = m.get_price(sym) + try: + return float(px) if px is not None else None + except (TypeError, ValueError): + return None + + +def _should_finalize_trend_flat(row, pos, plan_id: int, m) -> bool: + """首仓后交易所报无仓:需过开仓宽限期 + 连续空仓轮询,避免误判止损。""" + if pos is None: + return False + if float(pos) > 0: + _TREND_FLAT_STREAK.pop(plan_id, None) + return False + if not int(row["first_order_done"] or 0): + return False + age = _trend_plan_open_age_sec(row, m) + if age < TREND_OPEN_GRACE_SEC: + _TREND_FLAT_STREAK.pop(plan_id, None) + return False + try: + local_open = float(row["order_amount_open"] or 0) + except (TypeError, ValueError): + local_open = 0.0 + required = TREND_FLAT_CONFIRM_POLLS + if local_open > 0 and age < TREND_OPEN_GRACE_SEC * 2: + required = max(required, TREND_FLAT_CONFIRM_POLLS * 2) + streak = int(_TREND_FLAT_STREAK.get(plan_id, 0)) + 1 + _TREND_FLAT_STREAK[plan_id] = streak + if streak >= required: + print( + f"[trend_pullback] flat finalize plan={plan_id} sym={row['symbol']} " + f"age={age:.0f}s streak={streak} local_open={local_open}", + flush=True, + ) + return True + return False + + +def check_trend_pullback_plans(cfg: dict) -> None: + m = _m(cfg) + ok_live, live_reason = m.ensure_exchange_live_ready() + _TREND_POLL_STATE["updated_at"] = time.time() + _TREND_POLL_STATE["live_ok"] = ok_live + _TREND_POLL_STATE["live_reason"] = live_reason or "" + if not ok_live: + _log_trend_live_skip(live_reason or "unknown") + conn = cfg["get_db"]() + try: + for row in conn.execute( + "SELECT * FROM trend_pullback_plans WHERE status='active'" + ).fetchall(): + probe = summarize_trend_dca_probe(cfg, row) + if probe.get("trigger_reached"): + _set_trend_poll_plan(int(row["id"]), probe) + except Exception as e: + print(f"[trend_pullback] live-skip probe error: {e}", flush=True) + finally: + conn.close() + return + conn = cfg["get_db"]() + rows = conn.execute( + "SELECT * FROM trend_pullback_plans WHERE status='active'" + ).fetchall() + for row in rows: + try: + plan_id = int(row["id"]) + sym = row["symbol"] + direction = (row["direction"] or "long").lower() + ex_sym = row["exchange_symbol"] or m.normalize_exchange_symbol(sym) + sl = float(row["stop_loss"]) + tp = float(row["take_profit"]) + lev = int(row["leverage"] or 1) + try: + local_open = float(row["order_amount_open"] or 0) + except (TypeError, ValueError): + local_open = 0.0 + pf = _trend_poll_price(m, sym, ex_sym, direction) + if pf is None: + continue + last_p = row["last_mark_price"] + last_pf = float(last_p) if last_p is not None else pf + pos = m.get_live_position_contracts(ex_sym, direction) + if pos is None: + if local_open > 0 and int(row["first_order_done"] or 0): + pos = local_open + else: + continue + elif float(pos) <= 0 and local_open > 0: + age = _trend_plan_open_age_sec(row, m) + if age < TREND_OPEN_GRACE_SEC * 2: + print( + f"[trend_pullback] pos fallback plan={plan_id} sym={sym} " + f"ex_pos=0 local_open={local_open} age={age:.0f}s", + flush=True, + ) + pos = local_open + legs_done = int(row["legs_done"] or 0) + try: + leg_amounts = [float(x) for x in json.loads(row["leg_amounts_json"] or "[]")] + except Exception: + leg_amounts = [] + try: + grid = json.loads(row["grid_prices_json"] or "[]") + except Exception: + grid = [] + avg_e = float(row["avg_entry_price"] or pf or 0) + hit_tp = _trend_hit_take_profit(direction, pf, tp, avg_e) + if hit_tp and pos > 0: + try: + close_resp = trend_market_close(cfg, ex_sym, direction, float(pos), lev) + exit_p = m.extract_trade_price_from_order(close_resp) or pf + except Exception as e: + if not m.is_no_position_error(str(e)): + continue + exit_p = pf + _finalize_plan(cfg, conn, row, "止盈", exit_p) + _TREND_FLAT_STREAK.pop(plan_id, None) + continue + if _should_finalize_trend_flat(row, pos, plan_id, m): + _finalize_plan(cfg, conn, row, "止损", pf) + _TREND_FLAT_STREAK.pop(plan_id, None) + continue + if int(row["first_order_done"] or 0) and legs_done < len(grid) and legs_done < len(leg_amounts): + while legs_done < len(grid) and legs_done < len(leg_amounts): + level = float(grid[legs_done]) + if not trend_dca_level_reached(direction, pf, level): + break + amt = float(m.exchange.amount_to_precision(ex_sym, leg_amounts[legs_done])) + if amt <= 0: + print( + f"[trend_pullback] dca skip plan={plan_id} leg={legs_done + 1} " + f"amt_precision=0 raw={leg_amounts[legs_done]}", + flush=True, + ) + break + try: + add_resp = trend_market_add(cfg, ex_sym, direction, amt, lev) + except Exception as e: + print( + f"[trend_pullback] dca order failed plan={plan_id} sym={sym} " + f"leg={legs_done + 1} level={level} mark={pf} err={e}", + flush=True, + ) + break + fill_px = m.extract_trade_price_from_order(add_resp) or pf + old_avg = float(row["avg_entry_price"] or fill_px) + old_open = float(row["order_amount_open"] or 0) + new_avg = _weighted_avg(old_avg, old_open, fill_px, amt) + legs_done += 1 + from lib.strategy.strategy_trend_lib import append_leg_fill_price_json + + fills_json = append_leg_fill_price_json( + row["leg_fill_prices_json"] if "leg_fill_prices_json" in row.keys() else None, + fill_px, + ) + conn.execute( + "UPDATE trend_pullback_plans SET legs_done=?, avg_entry_price=?, " + "order_amount_open=?, last_mark_price=?, leg_fill_prices_json=? WHERE id=?", + (legs_done, new_avg, old_open + amt, pf, fills_json, row["id"]), + ) + row = conn.execute( + "SELECT * FROM trend_pullback_plans WHERE id=?", (row["id"],) + ).fetchone() + print( + f"[trend_pullback] dca filled plan={plan_id} leg={legs_done} " + f"fill={fill_px} avg={new_avg} open={old_open + amt}", + flush=True, + ) + try: + trend_refresh_stop_only(cfg, ex_sym, direction, sl) + except Exception: + pass + conn.execute( + "UPDATE trend_pullback_plans SET last_mark_price=? WHERE id=?", + (pf, row["id"]), + ) + probe = summarize_trend_dca_probe(cfg, row) + probe["last_poll_mark"] = pf + _set_trend_poll_plan(plan_id, probe) + if probe.get("trigger_reached") and probe.get("block_reason"): + print( + f"[trend_pullback] dca blocked plan={plan_id} sym={sym} " + f"mark={pf} next={probe.get('next_trigger')} reason={probe.get('block_reason')}", + flush=True, + ) + except Exception as e: + print( + f"[trend_pullback] poll error plan={row['id'] if row else '?'}: {e}", + flush=True, + ) + continue + conn.commit() + conn.close() + + +TREND_PLAN_STATUS_HANDOFF = "stopped_handoff" + + +def _order_monitor_manual_type(m) -> str: + return getattr(m, "ORDER_MONITOR_TYPE_MANUAL", None) or "下单监控" + + +def _insert_trend_handoff_order_monitor( + cfg: dict, + conn, + plan_row, + *, + new_sl: float, + pos_amt: float, +) -> int: + m = _m(cfg) + sym = plan_row["symbol"] + direction = (plan_row["direction"] or "long").lower() + ex_sym = plan_row["exchange_symbol"] or m.normalize_exchange_symbol(sym) + plan_id = int(plan_row["id"]) + avg_e = float(plan_row["avg_entry_price"] or 0) + tp = float(plan_row["take_profit"] or 0) + lev = int(plan_row["leverage"] or 1) + margin_cap = float(plan_row["plan_margin_capital"] or 0) + init_sl = float( + plan_row["initial_stop_loss"] + if plan_row["initial_stop_loss"] not in (None, "") + else plan_row["stop_loss"] + or 0 + ) + risk_pct = float(plan_row["risk_percent"] or 5) + risk_amt = None + calc_risk = getattr(m, "calc_risk_amount_from_plan", None) + if callable(calc_risk): + try: + risk_amt = calc_risk(direction, avg_e, init_sl, margin_cap, lev) + except Exception: + risk_amt = None + be_rr = float(getattr(m, "BREAKEVEN_RR_TRIGGER", 1) or 1) + be_off = float(getattr(m, "BREAKEVEN_OFFSET_PCT", 0.3) or 0.3) + be_step = float(getattr(m, "BREAKEVEN_STEP_R", 1) or 1) + if direction == "short": + be_price = round(avg_e * (1 - be_off / 100.0), 8) + else: + be_price = round(avg_e * (1 + be_off / 100.0), 8) + rp = getattr(m, "round_price_to_exchange", None) + if callable(rp): + try: + be_price = float(rp(ex_sym, be_price) or be_price) + except Exception: + pass + opened_at = plan_row["opened_at"] or m.app_now_str() + to_ms = getattr(m, "_to_ms_with_fallback", None) + opened_ms = to_ms(plan_row["opened_at_ms"] if "opened_at_ms" in plan_row.keys() else None, opened_at) if callable(to_ms) else None + trading_day = plan_row["session_date"] or getattr(m, "get_trading_day", lambda: None)() + if not trading_day and callable(getattr(m, "get_trading_day", None)): + trading_day = m.get_trading_day() + notional = margin_cap * lev if margin_cap and lev else None + monitor_type = MONITOR_TYPE_TREND_PULLBACK + conn.execute( + "INSERT INTO order_monitors " + "(symbol, exchange_symbol, direction, trigger_price, stop_loss, initial_stop_loss, take_profit, " + "margin_capital, leverage, trade_style, risk_percent, risk_amount, " + "breakeven_rr_trigger, breakeven_offset_pct, breakeven_step_r, breakeven_armed, breakeven_price, " + "breakeven_enabled, notional_value, position_ratio, base_amount, order_amount, exchange_order_id, " + "opened_at, opened_at_ms, session_date, monitor_type, key_signal_type, trend_plan_id) " + "VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + ( + sym, + ex_sym, + direction, + avg_e, + new_sl, + init_sl, + tp, + margin_cap, + lev, + "trend_pullback_handoff", + risk_pct, + risk_amt, + be_rr, + be_off, + be_step, + 0, + be_price, + 0, + notional, + None, + None, + float(pos_amt), + "", + opened_at, + opened_ms, + trading_day, + monitor_type, + TREND_HANDOFF_KEY_SIGNAL, + plan_id, + ), + ) + new_id = int(conn.execute("SELECT last_insert_rowid()").fetchone()[0]) + persist = getattr(m, "try_persist_exchange_margin_for_order", None) + if callable(persist): + try: + persist(conn, new_id, ex_sym, direction, order_leverage=lev) + except Exception: + pass + return new_id + + +def apply_manual_breakeven(cfg: dict, conn, row, offset_pct=None) -> tuple[bool, Optional[str]]: + """保本:结束趋势计划,持仓移交下单监控(备注趋势回调),交易所同时挂保本止损与止盈。""" + m = _m(cfg) + if (row["status"] or "").strip() != "active": + return False, "计划已结束" + if not int(row["first_order_done"] or 0): + return False, "尚未完成首仓,无法保本" + avg_e = float(row["avg_entry_price"] or 0) + if avg_e <= 0: + return False, "缺少有效持仓均价" + direction = (row["direction"] or "long").lower() + sym = row["symbol"] + ex_sym = row["exchange_symbol"] or m.normalize_exchange_symbol(sym) + pos = m.get_live_position_contracts(ex_sym, direction) + if pos is None or float(pos) <= 0: + return False, "交易所当前无该方向持仓" + pos_amt = float(pos) + dup = conn.execute( + "SELECT id FROM order_monitors WHERE status='active' AND symbol=? AND direction=? LIMIT 1", + (sym, direction), + ).fetchone() + if dup: + return False, "该币种已有运行中的下单监控,请先结束后再保本移交" + be_fn = getattr(m, "calc_trend_manual_breakeven_stop", None) + if not callable(be_fn): + pct = float(offset_pct if offset_pct is not None else cfg["breakeven_offset_pct"]) + if direction == "short": + new_sl_raw = avg_e * (1.0 - pct / 100.0) + else: + new_sl_raw = avg_e * (1.0 + pct / 100.0) + else: + new_sl_raw = be_fn(direction, avg_e, offset_pct) + if new_sl_raw is None: + return False, "保本价计算失败" + new_sl = m.round_price_to_exchange(ex_sym, new_sl_raw) + if new_sl is None: + return False, "保本价经交易所精度舍入后无效" + new_sl = float(new_sl) + tp = float(row["take_profit"] or 0) + if tp <= 0: + return False, "计划止盈价无效" + cur_sl = float(row["stop_loss"] or 0) + if direction == "long": + if new_sl <= cur_sl: + return False, f"新止损 {new_sl} 未高于当前止损 {cur_sl}(多仓需上移)" + else: + if new_sl >= cur_sl: + return False, f"新止损 {new_sl} 未低于当前止损 {cur_sl}(空仓需下移)" + ok_live, live_reason = m.ensure_exchange_live_ready() + if not ok_live: + return False, live_reason or "实盘未就绪" + plan_id = int(row["id"]) + try: + from lib.strategy.strategy_snapshot_lib import save_trend_plan_snapshot + + save_trend_plan_snapshot( + cfg, conn, row, result_label="保本移交", exit_price=None, pnl_amount=None + ) + except Exception: + pass + handoff_row = { + "symbol": sym, + "exchange_symbol": ex_sym, + "direction": direction, + "order_amount": pos_amt, + } + try: + trend_replace_tpsl(cfg, handoff_row, new_sl, tp) + except Exception as e: + fe = getattr(m, "friendly_exchange_error", None) + return False, fe(e) if callable(fe) else str(e) + now_s = m.app_now_str() + _TREND_FLAT_STREAK.pop(plan_id, None) + cur = conn.execute( + "UPDATE trend_pullback_plans SET status=?, message=?, stop_loss=?, " + "breakeven_applied=1, breakeven_applied_at=? WHERE id=? AND status='active'", + ( + TREND_PLAN_STATUS_HANDOFF, + f"保本移交下单监控({TREND_HANDOFF_TRADE_NOTE})", + new_sl, + now_s, + plan_id, + ), + ) + if not getattr(cur, "rowcount", 0): + return False, "计划状态更新失败(可能已被其他操作结束)" + try: + mon_id = _insert_trend_handoff_order_monitor( + cfg, conn, row, new_sl=new_sl, pos_amt=pos_amt + ) + except Exception as e: + conn.execute( + "UPDATE trend_pullback_plans SET status='active', message=? WHERE id=?", + (f"移交下单监控失败:{e}", plan_id), + ) + return False, f"移交下单监控失败:{e}" + pct_used = float( + offset_pct if offset_pct is not None else cfg["breakeven_offset_pct"] + ) + extra = getattr(m, "build_wechat_close_message", None) + send = getattr(m, "send_wechat_msg", None) + pf = getattr(m, "format_price_for_symbol", None) + fmt = (lambda s, p: pf(s, p)) if callable(pf) else (lambda _s, p: str(p)) + try: + from lib.strategy.strategy_wechat_notify import notify_trend_plan_ended + + notify_trend_plan_ended( + cfg, + plan_id=plan_id, + symbol=sym, + direction=direction, + end_type="保本移交", + result_label=TREND_HANDOFF_TRADE_NOTE, + extra=f"已移交下单监控 #{mon_id};止损 {fmt(sym, new_sl)} | 止盈 {fmt(sym, tp)}", + ) + except Exception: + pass + if callable(send): + lines = [ + f"# ✅ {sym} 趋势回调保本移交", + f"- 计划 ID:**{plan_id}** → 下单监控 **#{mon_id}**", + f"- 备注:**{TREND_HANDOFF_TRADE_NOTE}**", + f"- 保本止损:{fmt(sym, new_sl)} | 止盈:{fmt(sym, tp)}", + f"- 交易所:已挂止盈止损;平仓后将写入交易记录({ENTRY_REASON_TREND_PULLBACK})", + ] + wl = getattr(m, "_wechat_account_label", None) + if callable(wl): + lines.insert(1, f"**账户:{wl()}**") + send("\n".join(lines)) + return True, None + + +def load_trend_page_context(conn, request_obj, cfg: dict) -> dict[str, Any]: + m = _m(cfg) + _cleanup_stale_previews(conn) + trend_active = int( + conn.execute( + "SELECT COUNT(*) FROM trend_pullback_plans WHERE status='active'" + ).fetchone()[0] + or 0 + ) + trend_plans = [] + trend_dca_probes = [] + raw_plans = conn.execute( + "SELECT * FROM trend_pullback_plans WHERE status='active' ORDER BY id DESC" + ).fetchall() + for r in raw_plans: + try: + enriched = enrich_trend_plan(cfg, r) + trend_plans.append(enriched) + except Exception: + enriched = _row(cfg, r) + trend_plans.append(enriched) + try: + probe = summarize_trend_dca_probe(cfg, r) + trend_dca_probes.append(probe) + if isinstance(enriched, dict): + enriched["dca_probe"] = probe + except Exception: + pass + now = m.app_now() + active_count = m.get_active_position_count(conn) + from lib.trade.daily_open_limit_lib import can_trade_new_open, count_opens_for_trading_day + + trading_day = m.get_trading_day(now) + opens_today = count_opens_for_trading_day(conn, trading_day) + hard_limit = int(getattr(m, "DAILY_OPEN_HARD_LIMIT", 0) or 0) + can_trade_trend = can_trade_new_open( + time_allows=m.trading_day_reset_allows_new_open(now), + active_count=active_count, + max_active_positions=cfg["max_active_positions"], + opens_today=opens_today, + hard_limit=hard_limit, + extra_blocks=trend_active != 0, + ) + trend_preview = None + trend_preview_levels = [] + preview_expires_ms = None + trend_preview_expired = False + pid_arg = (request_obj.args.get("preview_id") or "").strip() + if pid_arg: + pr = conn.execute( + "SELECT * FROM trend_pullback_previews WHERE id=?", (pid_arg,) + ).fetchone() + now_ms = int(time.time() * 1000) + if pr and int(pr["expires_at_ms"] or 0) >= now_ms: + from lib.strategy.strategy_trend_lib import build_trend_preview_level_rows + + trend_preview = _row(cfg, pr) + preview_expires_ms = int(pr["expires_at_ms"]) + get_cs = getattr(m, "get_contract_size", None) + if callable(get_cs) and not trend_preview.get("contract_size"): + try: + trend_preview["contract_size"] = float( + get_cs(trend_preview.get("exchange_symbol") or trend_preview.get("symbol") or "") + ) + except (TypeError, ValueError): + pass + trend_preview, trend_preview_levels = build_trend_preview_level_rows(trend_preview) + elif pr: + trend_preview_expired = True + return { + "trend_plans": trend_plans, + "trend_dca_probes": trend_dca_probes, + "trend_active": trend_active, + "can_trade_trend": can_trade_trend, + "trend_preview": trend_preview, + "trend_preview_levels": trend_preview_levels, + "preview_expires_ms": preview_expires_ms, + "trend_preview_expired": trend_preview_expired, + "trend_pullback_dca_legs": cfg["dca_legs"], + "trend_pullback_preview_ttl": cfg["preview_ttl"], + "trend_preview_max_drift_pct": cfg["drift_pct"], + "trend_manual_breakeven_offset_pct": cfg["breakeven_offset_pct"], + } + + +def register_trend_routes(app: Flask, cfg: dict) -> None: + lr = cfg["login_required"] + get_db = cfg["get_db"] + + def _redirect_trend(**kw): + return redirect(url_for("strategy_trading_page", **kw)) + + @app.route("/preview_trend_pullback", methods=["POST"]) + @lr + def preview_trend_pullback(): + conn = get_db() + init_strategy_tables(conn) + okp, msg = precheck_trend_start(cfg, conn) + if not okp: + conn.close() + flash(msg) + return _redirect_trend() + m = _m(cfg) + ok_live, reason = m.ensure_exchange_live_ready() + if not ok_live: + conn.close() + flash(reason) + return _redirect_trend() + payload, err = parse_trend_plan(cfg, request.form) + if err: + conn.close() + flash(err) + return _redirect_trend() + pid = str(uuid.uuid4()) + exp_ms = int(time.time() * 1000) + cfg["preview_ttl"] * 1000 + created = m.app_now_str() + conn.execute( + """INSERT INTO trend_pullback_previews ( + id,symbol,exchange_symbol,direction,leverage,stop_loss,add_upper,take_profit,risk_percent, + snapshot_available_usdt,snapshot_at,live_price_ref,plan_margin_capital,target_order_amount,first_order_amount,remainder_total, + dca_legs,per_leg_amount,grid_prices_json,leg_amounts_json,expires_at_ms,created_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + pid, + payload["symbol"], + payload["exchange_symbol"], + payload["direction"], + payload["leverage"], + payload["stop_loss"], + payload["add_upper"], + payload["take_profit"], + payload["risk_percent"], + payload["snapshot_available_usdt"], + payload["snapshot_at"], + payload["live_price_ref"], + payload["plan_margin_capital"], + payload["target_order_amount"], + payload["first_order_amount"], + payload["remainder_total"], + payload["dca_legs"], + payload["per_leg_amount"], + payload["grid_prices_json"], + payload["leg_amounts_json"], + exp_ms, + created, + ), + ) + _insert_preview_snapshot(conn, pid, created, exp_ms, payload) + conn.commit() + conn.close() + flash(f"预览已生成,有效期 {cfg['preview_ttl']} 秒,请核对后点击「确认执行」。") + return _redirect_trend(preview_id=pid) + + @app.route("/execute_trend_pullback", methods=["POST"]) + @lr + def execute_trend_pullback(): + pid = (request.form.get("preview_id") or "").strip() + if not pid: + flash("缺少预览 ID") + return _redirect_trend() + conn = get_db() + init_strategy_tables(conn) + _cleanup_stale_previews(conn) + pr = conn.execute( + "SELECT * FROM trend_pullback_previews WHERE id=?", (pid,) + ).fetchone() + now_ms = int(time.time() * 1000) + if not pr or int(pr["expires_at_ms"] or 0) < now_ms: + conn.close() + flash("预览已过期或不存在,请重新生成预览") + return _redirect_trend() + okp, msg = precheck_trend_start(cfg, conn) + if not okp: + conn.close() + flash(msg) + return _redirect_trend(preview_id=pid) + m = _m(cfg) + ok_live, reason = m.ensure_exchange_live_ready() + if not ok_live: + conn.close() + flash(reason) + return _redirect_trend(preview_id=pid) + snap_prev = float(pr["snapshot_available_usdt"] or 0) + snap_now = m.get_available_trading_usdt() + if snap_now is None or snap_now <= 0: + conn.close() + flash("无法读取当前合约可用余额,请稍后重试") + return _redirect_trend(preview_id=pid) + drift = abs(float(snap_now) - snap_prev) / max(snap_prev, 1e-9) * 100.0 + if drift > cfg["drift_pct"]: + conn.close() + flash( + f"当前可用余额与预览快照偏差 {drift:.2f}%,超过允许 {cfg['drift_pct']}%,请重新生成预览" + ) + return _redirect_trend(preview_id=pid) + symbol = pr["symbol"] + exchange_symbol = pr["exchange_symbol"] + direction = pr["direction"] or "long" + leverage = int(pr["leverage"] or 1) + stop_loss = float(pr["stop_loss"]) + first_amt = float(pr["first_order_amount"] or 0) + live_price = m.get_price(symbol) + if live_price is None: + conn.close() + flash("获取实时价格失败") + return _redirect_trend(preview_id=pid) + try: + o1 = m.place_exchange_order( + exchange_symbol, direction, first_amt, leverage, stop_loss=None, take_profit=None + ) + fill1 = m.resolve_order_entry_price(o1, exchange_symbol, live_price) + trend_refresh_stop_only(cfg, exchange_symbol, direction, stop_loss) + except Exception as e: + conn.close() + fe = getattr(m, "friendly_exchange_error", lambda x, **k: str(x)) + flash(fe(e, available_usdt=snap_now)) + return _redirect_trend(preview_id=pid) + trading_day = m.get_trading_day(m.app_now()) + opened_at = m.app_now_str() + opened_ms = getattr(m, "_to_ms_with_fallback", lambda a, b: None)(None, opened_at) + from lib.strategy.strategy_trend_lib import append_leg_fill_price_json + + fills_json = append_leg_fill_price_json(None, fill1) + cur = conn.execute( + """INSERT INTO trend_pullback_plans ( + status,symbol,exchange_symbol,direction,leverage,stop_loss,initial_stop_loss,add_upper,take_profit,risk_percent, + snapshot_available_usdt,snapshot_at,plan_margin_capital,target_order_amount,first_order_amount,remainder_total, + dca_legs,per_leg_amount,grid_prices_json,leg_amounts_json,legs_done,first_order_done,last_mark_price,avg_entry_price,order_amount_open,opened_at,opened_at_ms,session_date,message,leg_fill_prices_json + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + "active", + symbol, + exchange_symbol, + direction, + leverage, + stop_loss, + stop_loss, + float(pr["add_upper"]), + float(pr["take_profit"]), + float(pr["risk_percent"] or 5), + float(snap_now), + opened_at, + float(pr["plan_margin_capital"] or 0), + float(pr["target_order_amount"] or 0), + first_amt, + float(pr["remainder_total"] or 0), + int(pr["dca_legs"] or 0), + float(pr["per_leg_amount"] or 0), + pr["grid_prices_json"] or "[]", + pr["leg_amounts_json"] or "[]", + 0, + 1, + float(live_price), + fill1, + first_amt, + opened_at, + opened_ms, + trading_day, + f"预览ID:{pid[:8]}…", + fills_json, + ), + ) + new_id = int(cur.lastrowid) + conn.execute( + "UPDATE trend_pullback_preview_snapshots SET outcome='executed', executed_plan_id=? WHERE preview_id=?", + (new_id, pid), + ) + conn.execute("DELETE FROM trend_pullback_previews WHERE id=?", (pid,)) + conn.commit() + try: + from lib.strategy.strategy_wechat_notify import notify_trend_plan_started + + notify_trend_plan_started( + cfg, + plan_id=new_id, + symbol=symbol, + direction=direction, + leverage=leverage, + stop_loss=stop_loss, + take_profit=float(pr["take_profit"]), + add_upper=float(pr["add_upper"]), + risk_percent=float(pr["risk_percent"] or 5), + dca_legs=int(pr["dca_legs"] or 0), + first_order_amount=first_amt, + avg_entry=fill1, + snapshot_usdt=float(snap_now), + ) + except Exception: + pass + conn.close() + flash("趋势回调已执行:首仓已成交并挂交易所止损,止盈由程序监控。") + return _redirect_trend() + + @app.route("/cancel_trend_pullback_preview", methods=["POST"]) + @lr + def cancel_trend_pullback_preview(): + pid = (request.form.get("preview_id") or "").strip() + conn = get_db() + if pid: + conn.execute( + "UPDATE trend_pullback_preview_snapshots SET outcome='cancelled' WHERE preview_id=? AND outcome='open'", + (pid,), + ) + conn.execute("DELETE FROM trend_pullback_previews WHERE id=?", (pid,)) + conn.commit() + conn.close() + flash("已取消预览") + return _redirect_trend() + + @app.route("/trend_pullback_breakeven/", methods=["POST"]) + @lr + def trend_pullback_breakeven(pid: int): + offset_pct = None + raw = (request.form.get("breakeven_offset_pct") or "").strip() + if raw: + try: + offset_pct = float(raw) + if offset_pct < 0: + raise ValueError + except ValueError: + flash("保本偏移% 格式无效") + return _redirect_trend() + conn = get_db() + row = conn.execute( + "SELECT * FROM trend_pullback_plans WHERE id=? AND status='active'", (pid,) + ).fetchone() + if not row: + conn.close() + flash("未找到运行中的趋势回调计划") + return _redirect_trend() + ok, err = apply_manual_breakeven(cfg, conn, row, offset_pct=offset_pct) + conn.commit() + conn.close() + flash( + "已保本:趋势计划已结束,持仓已移交下单监控并挂止盈止损;平仓后将写入交易记录" + if ok + else (err or "保本移交失败") + ) + return _redirect_trend() + + @app.route("/stop_trend_pullback/") + @lr + def stop_trend_pullback(pid: int): + conn = get_db() + row = conn.execute( + "SELECT * FROM trend_pullback_plans WHERE id=? AND status='active'", (pid,) + ).fetchone() + if not row: + stopped = conn.execute( + "SELECT * FROM trend_pullback_plans WHERE id=? " + "AND status IN ('stopped_sl','stopped_tp','stopped_manual')", + (pid,), + ).fetchone() + if stopped and not _trend_plan_trade_exists(conn, pid): + try: + if _ensure_trend_plan_trade_record(cfg, conn, pid, prefer_label="手动平仓"): + conn.close() + flash("计划已结束,已补录缺失的交易记录") + return _redirect_trend() + except Exception as e: + conn.close() + flash(f"补录交易记录失败:{e}") + return _redirect_trend() + conn.close() + flash("未找到运行中的趋势回调计划") + return _redirect_trend() + m = _m(cfg) + ex_sym = row["exchange_symbol"] or m.normalize_exchange_symbol(row["symbol"]) + direction = row["direction"] or "long" + lev = int(row["leverage"] or 1) + px = m.get_price(row["symbol"]) + exit_p = float(px) if px is not None else 0.0 + ok_live, _ = m.ensure_exchange_live_ready() + if ok_live: + pos = m.get_live_position_contracts(ex_sym, direction) + if pos is not None and pos > 0: + try: + close_resp = trend_market_close(cfg, ex_sym, direction, float(pos), lev) + ep = m.extract_trade_price_from_order(close_resp) + if ep: + exit_p = float(ep) + except Exception as e: + if not m.is_no_position_error(str(e)): + conn.close() + flash(f"平仓失败:{e}") + return _redirect_trend() + try: + cancel_symbol_orders(cfg, ex_sym) + except Exception: + pass + try: + _finalize_plan(cfg, conn, row, "手动平仓", exit_p, user_initiated_risk=True) + except Exception as e: + conn.execute( + "UPDATE trend_pullback_plans SET status='stopped_manual', message=? " + "WHERE id=? AND status='active'", + (f"结束异常:{e}", pid), + ) + conn.commit() + conn.close() + flash(f"计划已结束但记账可能不完整:{e}") + return _redirect_trend() + conn.close() + flash("已结束趋势回调计划") + return _redirect_trend() diff --git a/strategy_ui.py b/lib/strategy/strategy_ui.py similarity index 89% rename from strategy_ui.py rename to lib/strategy/strategy_ui.py index 52a877a..92ddc6f 100644 --- a/strategy_ui.py +++ b/lib/strategy/strategy_ui.py @@ -1,144 +1,144 @@ -"""策略交易页:主站 index.html 所需数据(顺势加仓等)。""" -from __future__ import annotations - -from typing import Any, Callable, Optional - -from strategy_db import init_strategy_tables -from strategy_roll_monitor_lib import roll_leg_status_label - - -def _row_to_dict(row) -> dict: - if row is None: - return {} - try: - return dict(row) - except Exception: - return {} - - -def count_active_trend_plans(conn, count_fn: Optional[Callable] = None) -> int: - if callable(count_fn): - return int(count_fn(conn) or 0) - try: - return int( - conn.execute( - "SELECT COUNT(*) FROM trend_pullback_plans WHERE status='active'" - ).fetchone()[0] - ) - except Exception: - return 0 - - -def fetch_roll_page_data( - conn, - *, - default_risk_percent: float = 2.0, - count_active_trends: Optional[Callable] = None, - roll_cfg: dict | None = None, -) -> dict[str, Any]: - init_strategy_tables(conn) - monitors = [] - for row in conn.execute( - "SELECT * FROM order_monitors WHERE status='active' ORDER BY id DESC" - ).fetchall(): - monitors.append(_row_to_dict(row)) - roll_groups = [] - for row in conn.execute( - """SELECT g.* FROM roll_groups g - INNER JOIN order_monitors m ON m.id = g.order_monitor_id AND m.status='active' - WHERE g.status='active' - ORDER BY g.id DESC""" - ).fetchall(): - roll_groups.append(_row_to_dict(row)) - active_gids = {int(g["id"]) for g in roll_groups if g.get("id") is not None} - roll_legs = [] - for row in conn.execute( - "SELECT * FROM roll_legs ORDER BY id DESC LIMIT 80" - ).fetchall(): - leg = _row_to_dict(row) - gid = leg.get("roll_group_id") - if gid is not None and int(gid) not in active_gids: - continue - leg["status_label"] = roll_leg_status_label(leg.get("status")) - roll_legs.append(leg) - roll_legs = roll_legs[:50] - out = { - "roll_monitors": monitors, - "roll_groups": roll_groups, - "roll_legs": roll_legs, - "roll_trend_active": count_active_trend_plans(conn, count_active_trends), - "default_risk_percent": default_risk_percent, - } - if roll_cfg: - from strategy_roll_ui_lib import enrich_roll_page_data - - enrich_roll_page_data(conn, out, roll_cfg) - return out - - -DEFAULT_TREND_DISABLED_NOTE = ( - "趋势回调(预览、自动补仓、程序止盈)仅在 Gate 趋势机器人实例 " - "(crypto_monitor_gate_bot,常见端口 5002)中启用。" - "币安 / Gate 主站 / OKX 可使用本页「顺势加仓」;完整趋势回调请打开该实例。" -) - - -def strategy_render_extras( - conn, - page: str, - *, - default_risk_percent: float = 2.0, - count_active_trends: Optional[Callable] = None, - trend_disabled_note: str = "", - request_obj=None, - trend_cfg: Optional[dict] = None, -) -> dict[str, Any]: - """render_main_page 策略相关页变量(含策略交易记录)。""" - if page == "strategy_records": - from strategy_records_register import load_strategy_records_page - - return load_strategy_records_page(conn) - return strategy_page_template_vars( - conn, - page, - default_risk_percent=default_risk_percent, - count_active_trends=count_active_trends, - trend_disabled_note=trend_disabled_note, - request_obj=request_obj, - trend_cfg=trend_cfg, - ) - - -def strategy_page_template_vars( - conn, - page: str, - *, - default_risk_percent: float = 2.0, - count_active_trends: Optional[Callable] = None, - trend_disabled_note: str = "", - request_obj=None, - trend_cfg: Optional[dict] = None, -) -> dict[str, Any]: - """render_main_page 在 conn.close() 前合并进 render_template 的变量。""" - if page not in ("strategy", "strategy_trend", "strategy_roll"): - return {} - roll_cfg = None - try: - from flask import current_app - - roll_cfg = (current_app.extensions or {}).get("strategy_roll_cfg") - except Exception: - roll_cfg = None - out = fetch_roll_page_data( - conn, - default_risk_percent=default_risk_percent, - count_active_trends=count_active_trends, - roll_cfg=roll_cfg if isinstance(roll_cfg, dict) else None, - ) - if trend_cfg and request_obj is not None: - from strategy_trend_register import load_trend_page_context - - out.update(load_trend_page_context(conn, request_obj, trend_cfg)) - elif page == "strategy_trend": - out["trend_disabled_note"] = trend_disabled_note or DEFAULT_TREND_DISABLED_NOTE - return out +"""策略交易页:主站 index.html 所需数据(顺势加仓等)。""" +from __future__ import annotations + +from typing import Any, Callable, Optional + +from lib.strategy.strategy_db import init_strategy_tables +from lib.strategy.strategy_roll_monitor_lib import roll_leg_status_label + + +def _row_to_dict(row) -> dict: + if row is None: + return {} + try: + return dict(row) + except Exception: + return {} + + +def count_active_trend_plans(conn, count_fn: Optional[Callable] = None) -> int: + if callable(count_fn): + return int(count_fn(conn) or 0) + try: + return int( + conn.execute( + "SELECT COUNT(*) FROM trend_pullback_plans WHERE status='active'" + ).fetchone()[0] + ) + except Exception: + return 0 + + +def fetch_roll_page_data( + conn, + *, + default_risk_percent: float = 2.0, + count_active_trends: Optional[Callable] = None, + roll_cfg: dict | None = None, +) -> dict[str, Any]: + init_strategy_tables(conn) + monitors = [] + for row in conn.execute( + "SELECT * FROM order_monitors WHERE status='active' ORDER BY id DESC" + ).fetchall(): + monitors.append(_row_to_dict(row)) + roll_groups = [] + for row in conn.execute( + """SELECT g.* FROM roll_groups g + INNER JOIN order_monitors m ON m.id = g.order_monitor_id AND m.status='active' + WHERE g.status='active' + ORDER BY g.id DESC""" + ).fetchall(): + roll_groups.append(_row_to_dict(row)) + active_gids = {int(g["id"]) for g in roll_groups if g.get("id") is not None} + roll_legs = [] + for row in conn.execute( + "SELECT * FROM roll_legs ORDER BY id DESC LIMIT 80" + ).fetchall(): + leg = _row_to_dict(row) + gid = leg.get("roll_group_id") + if gid is not None and int(gid) not in active_gids: + continue + leg["status_label"] = roll_leg_status_label(leg.get("status")) + roll_legs.append(leg) + roll_legs = roll_legs[:50] + out = { + "roll_monitors": monitors, + "roll_groups": roll_groups, + "roll_legs": roll_legs, + "roll_trend_active": count_active_trend_plans(conn, count_active_trends), + "default_risk_percent": default_risk_percent, + } + if roll_cfg: + from lib.strategy.strategy_roll_ui_lib import enrich_roll_page_data + + enrich_roll_page_data(conn, out, roll_cfg) + return out + + +DEFAULT_TREND_DISABLED_NOTE = ( + "趋势回调(预览、自动补仓、程序止盈)仅在 Gate 趋势机器人实例 " + "(crypto_monitor_gate_bot,常见端口 5002)中启用。" + "币安 / Gate 主站 / OKX 可使用本页「顺势加仓」;完整趋势回调请打开该实例。" +) + + +def strategy_render_extras( + conn, + page: str, + *, + default_risk_percent: float = 2.0, + count_active_trends: Optional[Callable] = None, + trend_disabled_note: str = "", + request_obj=None, + trend_cfg: Optional[dict] = None, +) -> dict[str, Any]: + """render_main_page 策略相关页变量(含策略交易记录)。""" + if page == "strategy_records": + from lib.strategy.strategy_records_register import load_strategy_records_page + + return load_strategy_records_page(conn) + return strategy_page_template_vars( + conn, + page, + default_risk_percent=default_risk_percent, + count_active_trends=count_active_trends, + trend_disabled_note=trend_disabled_note, + request_obj=request_obj, + trend_cfg=trend_cfg, + ) + + +def strategy_page_template_vars( + conn, + page: str, + *, + default_risk_percent: float = 2.0, + count_active_trends: Optional[Callable] = None, + trend_disabled_note: str = "", + request_obj=None, + trend_cfg: Optional[dict] = None, +) -> dict[str, Any]: + """render_main_page 在 conn.close() 前合并进 render_template 的变量。""" + if page not in ("strategy", "strategy_trend", "strategy_roll"): + return {} + roll_cfg = None + try: + from flask import current_app + + roll_cfg = (current_app.extensions or {}).get("strategy_roll_cfg") + except Exception: + roll_cfg = None + out = fetch_roll_page_data( + conn, + default_risk_percent=default_risk_percent, + count_active_trends=count_active_trends, + roll_cfg=roll_cfg if isinstance(roll_cfg, dict) else None, + ) + if trend_cfg and request_obj is not None: + from lib.strategy.strategy_trend_register import load_trend_page_context + + out.update(load_trend_page_context(conn, request_obj, trend_cfg)) + elif page == "strategy_trend": + out["trend_disabled_note"] = trend_disabled_note or DEFAULT_TREND_DISABLED_NOTE + return out diff --git a/strategy_wechat_notify.py b/lib/strategy/strategy_wechat_notify.py similarity index 95% rename from strategy_wechat_notify.py rename to lib/strategy/strategy_wechat_notify.py index dabaabf..79e5a9e 100644 --- a/strategy_wechat_notify.py +++ b/lib/strategy/strategy_wechat_notify.py @@ -1,192 +1,192 @@ -"""策略计划(趋势回调 / 滚仓)开始与结束 — 企业微信推送(四所共用)。""" -from __future__ import annotations - -from typing import Any, Optional - -from wechat_notify_lib import wechat_direction_label - - -def _send(cfg: dict[str, Any], content: str) -> None: - fn = cfg.get("send_wechat") - if callable(fn): - try: - fn(content) - return - except Exception: - pass - m = cfg.get("app_module") - if m is not None: - sw = getattr(m, "send_wechat_msg", None) - if callable(sw): - try: - sw(content) - except Exception: - pass - - -def _account(cfg: dict[str, Any]) -> str: - fn = cfg.get("wechat_account_label") - if callable(fn): - try: - return str(fn()).strip() or _exchange(cfg) - except Exception: - pass - return _exchange(cfg) - - -def _exchange(cfg: dict[str, Any]) -> str: - return str(cfg.get("exchange_display") or "").strip() or "交易账户" - - -def _dir_text(cfg: dict[str, Any], direction: str) -> str: - fn = cfg.get("wechat_direction_text") - if callable(fn): - try: - return str(fn(direction)) - except Exception: - pass - return wechat_direction_label(direction) - - -def _fmt_price(cfg: dict[str, Any], symbol: str, price: Any) -> str: - if price is None or price == "": - return "—" - fn = cfg.get("format_price") or cfg.get("price_fmt") - if callable(fn): - try: - return str(fn(symbol, price)) - except Exception: - pass - m = cfg.get("app_module") - pf = getattr(m, "format_price_for_symbol", None) if m else None - if callable(pf): - try: - return str(pf(symbol, price)) - except Exception: - pass - try: - return str(round(float(price), 8)) - except (TypeError, ValueError): - return str(price) - - -def _fmt_pnl(pnl: Any) -> str: - if pnl is None: - return "—" - try: - v = float(pnl) - return f"{'+' if v > 0 else ''}{round(v, 2)} U" - except (TypeError, ValueError): - return str(pnl) - - -def notify_trend_plan_started( - cfg: dict[str, Any], - *, - plan_id: int, - symbol: str, - direction: str, - leverage: int, - stop_loss: float, - take_profit: float, - add_upper: float, - risk_percent: float, - dca_legs: int, - first_order_amount: float, - avg_entry: Optional[float] = None, - snapshot_usdt: Optional[float] = None, -) -> None: - sym = symbol or "—" - lines = [ - f"# 🚀 {sym} 趋势回调计划已开始", - f"**账户:{_account(cfg)}**", - f"- 计划 ID:**{plan_id}**", - f"- 方向:{_dir_text(cfg, direction)}|杠杆 **{int(leverage or 1)}x**", - f"- 止损:{_fmt_price(cfg, sym, stop_loss)}|止盈:{_fmt_price(cfg, sym, take_profit)}", - f"- 补仓区:{_fmt_price(cfg, sym, add_upper)}|补仓档 **{int(dca_legs or 0)}** 档", - f"- 风险:**{risk_percent}%**|首仓张数:**{first_order_amount}**", - ] - if avg_entry is not None: - lines.append(f"- 首仓成交价:{_fmt_price(cfg, sym, avg_entry)}") - if snapshot_usdt is not None: - try: - lines.append(f"- 启动时合约可用:**{round(float(snapshot_usdt), 2)} U**") - except (TypeError, ValueError): - pass - lines.append("- 说明:交易所已挂止损;止盈由程序监控;结束/保本将另行推送") - _send(cfg, "\n".join(lines)) - - -def notify_trend_plan_ended( - cfg: dict[str, Any], - *, - plan_id: int, - symbol: str, - direction: str, - end_type: str, - result_label: Optional[str] = None, - exit_price: Optional[float] = None, - pnl_amount: Optional[float] = None, - extra: Optional[str] = None, -) -> None: - sym = symbol or "—" - res = (result_label or end_type or "—").strip() - lines = [ - f"# 🏁 {sym} 趋势回调计划已结束", - f"**账户:{_account(cfg)}**", - f"- 计划 ID:**{plan_id}**", - f"- 方向:{_dir_text(cfg, direction)}", - f"- 结束方式:**{end_type}**", - f"- 结果:**{res}**", - ] - if exit_price is not None: - lines.append(f"- 离场参考价:{_fmt_price(cfg, sym, exit_price)}") - if pnl_amount is not None: - lines.append(f"- 本单盈亏:**{_fmt_pnl(pnl_amount)}**") - if extra: - lines.append(f"- {extra}") - _send(cfg, "\n".join(lines)) - - -def notify_roll_group_started( - cfg: dict[str, Any], - *, - group_id: int, - symbol: str, - direction: str, - order_monitor_id: int, - initial_take_profit: Optional[float] = None, - initial_stop_loss: Optional[float] = None, -) -> None: - sym = symbol or "—" - lines = [ - f"# 🚀 {sym} 滚仓计划已开始", - f"**账户:{_account(cfg)}**", - f"- 滚仓组 ID:**{group_id}**|绑定下单监控 **#{order_monitor_id}**", - f"- 方向:{_dir_text(cfg, direction)}", - f"- 首仓止盈(锁定):{_fmt_price(cfg, sym, initial_take_profit)}", - f"- 当前止损:{_fmt_price(cfg, sym, initial_stop_loss)}", - "- 说明:顺势加仓为人工触发;组结束(无持仓/监控结案)将另行推送", - ] - _send(cfg, "\n".join(lines)) - - -def notify_roll_group_ended( - cfg: dict[str, Any], - *, - group_id: int, - symbol: str, - direction: str, - reason: str, - leg_count: int = 0, -) -> None: - sym = symbol or "—" - lines = [ - f"# 🏁 {sym} 滚仓计划已结束", - f"**账户:{_account(cfg)}**", - f"- 滚仓组 ID:**{group_id}**", - f"- 方向:{_dir_text(cfg, direction)}", - f"- 结束原因:**{reason}**", - f"- 已完成滚仓腿数:**{int(leg_count or 0)}**", - ] - _send(cfg, "\n".join(lines)) +"""策略计划(趋势回调 / 滚仓)开始与结束 — 企业微信推送(四所共用)。""" +from __future__ import annotations + +from typing import Any, Optional + +from lib.common.wechat_notify_lib import wechat_direction_label + + +def _send(cfg: dict[str, Any], content: str) -> None: + fn = cfg.get("send_wechat") + if callable(fn): + try: + fn(content) + return + except Exception: + pass + m = cfg.get("app_module") + if m is not None: + sw = getattr(m, "send_wechat_msg", None) + if callable(sw): + try: + sw(content) + except Exception: + pass + + +def _account(cfg: dict[str, Any]) -> str: + fn = cfg.get("wechat_account_label") + if callable(fn): + try: + return str(fn()).strip() or _exchange(cfg) + except Exception: + pass + return _exchange(cfg) + + +def _exchange(cfg: dict[str, Any]) -> str: + return str(cfg.get("exchange_display") or "").strip() or "交易账户" + + +def _dir_text(cfg: dict[str, Any], direction: str) -> str: + fn = cfg.get("wechat_direction_text") + if callable(fn): + try: + return str(fn(direction)) + except Exception: + pass + return wechat_direction_label(direction) + + +def _fmt_price(cfg: dict[str, Any], symbol: str, price: Any) -> str: + if price is None or price == "": + return "—" + fn = cfg.get("format_price") or cfg.get("price_fmt") + if callable(fn): + try: + return str(fn(symbol, price)) + except Exception: + pass + m = cfg.get("app_module") + pf = getattr(m, "format_price_for_symbol", None) if m else None + if callable(pf): + try: + return str(pf(symbol, price)) + except Exception: + pass + try: + return str(round(float(price), 8)) + except (TypeError, ValueError): + return str(price) + + +def _fmt_pnl(pnl: Any) -> str: + if pnl is None: + return "—" + try: + v = float(pnl) + return f"{'+' if v > 0 else ''}{round(v, 2)} U" + except (TypeError, ValueError): + return str(pnl) + + +def notify_trend_plan_started( + cfg: dict[str, Any], + *, + plan_id: int, + symbol: str, + direction: str, + leverage: int, + stop_loss: float, + take_profit: float, + add_upper: float, + risk_percent: float, + dca_legs: int, + first_order_amount: float, + avg_entry: Optional[float] = None, + snapshot_usdt: Optional[float] = None, +) -> None: + sym = symbol or "—" + lines = [ + f"# 🚀 {sym} 趋势回调计划已开始", + f"**账户:{_account(cfg)}**", + f"- 计划 ID:**{plan_id}**", + f"- 方向:{_dir_text(cfg, direction)}|杠杆 **{int(leverage or 1)}x**", + f"- 止损:{_fmt_price(cfg, sym, stop_loss)}|止盈:{_fmt_price(cfg, sym, take_profit)}", + f"- 补仓区:{_fmt_price(cfg, sym, add_upper)}|补仓档 **{int(dca_legs or 0)}** 档", + f"- 风险:**{risk_percent}%**|首仓张数:**{first_order_amount}**", + ] + if avg_entry is not None: + lines.append(f"- 首仓成交价:{_fmt_price(cfg, sym, avg_entry)}") + if snapshot_usdt is not None: + try: + lines.append(f"- 启动时合约可用:**{round(float(snapshot_usdt), 2)} U**") + except (TypeError, ValueError): + pass + lines.append("- 说明:交易所已挂止损;止盈由程序监控;结束/保本将另行推送") + _send(cfg, "\n".join(lines)) + + +def notify_trend_plan_ended( + cfg: dict[str, Any], + *, + plan_id: int, + symbol: str, + direction: str, + end_type: str, + result_label: Optional[str] = None, + exit_price: Optional[float] = None, + pnl_amount: Optional[float] = None, + extra: Optional[str] = None, +) -> None: + sym = symbol or "—" + res = (result_label or end_type or "—").strip() + lines = [ + f"# 🏁 {sym} 趋势回调计划已结束", + f"**账户:{_account(cfg)}**", + f"- 计划 ID:**{plan_id}**", + f"- 方向:{_dir_text(cfg, direction)}", + f"- 结束方式:**{end_type}**", + f"- 结果:**{res}**", + ] + if exit_price is not None: + lines.append(f"- 离场参考价:{_fmt_price(cfg, sym, exit_price)}") + if pnl_amount is not None: + lines.append(f"- 本单盈亏:**{_fmt_pnl(pnl_amount)}**") + if extra: + lines.append(f"- {extra}") + _send(cfg, "\n".join(lines)) + + +def notify_roll_group_started( + cfg: dict[str, Any], + *, + group_id: int, + symbol: str, + direction: str, + order_monitor_id: int, + initial_take_profit: Optional[float] = None, + initial_stop_loss: Optional[float] = None, +) -> None: + sym = symbol or "—" + lines = [ + f"# 🚀 {sym} 滚仓计划已开始", + f"**账户:{_account(cfg)}**", + f"- 滚仓组 ID:**{group_id}**|绑定下单监控 **#{order_monitor_id}**", + f"- 方向:{_dir_text(cfg, direction)}", + f"- 首仓止盈(锁定):{_fmt_price(cfg, sym, initial_take_profit)}", + f"- 当前止损:{_fmt_price(cfg, sym, initial_stop_loss)}", + "- 说明:顺势加仓为人工触发;组结束(无持仓/监控结案)将另行推送", + ] + _send(cfg, "\n".join(lines)) + + +def notify_roll_group_ended( + cfg: dict[str, Any], + *, + group_id: int, + symbol: str, + direction: str, + reason: str, + leg_count: int = 0, +) -> None: + sym = symbol or "—" + lines = [ + f"# 🏁 {sym} 滚仓计划已结束", + f"**账户:{_account(cfg)}**", + f"- 滚仓组 ID:**{group_id}**", + f"- 方向:{_dir_text(cfg, direction)}", + f"- 结束原因:**{reason}**", + f"- 已完成滚仓腿数:**{int(leg_count or 0)}**", + ] + _send(cfg, "\n".join(lines)) diff --git a/strategy_templates/gate_transfer_block.html b/lib/strategy/templates/gate_transfer_block.html similarity index 100% rename from strategy_templates/gate_transfer_block.html rename to lib/strategy/templates/gate_transfer_block.html diff --git a/strategy_templates/key_focus_v2.html b/lib/strategy/templates/key_focus_v2.html similarity index 100% rename from strategy_templates/key_focus_v2.html rename to lib/strategy/templates/key_focus_v2.html diff --git a/strategy_templates/key_monitor_panel.html b/lib/strategy/templates/key_monitor_panel.html similarity index 100% rename from strategy_templates/key_monitor_panel.html rename to lib/strategy/templates/key_monitor_panel.html diff --git a/strategy_templates/key_monitor_rule_tips.html b/lib/strategy/templates/key_monitor_rule_tips.html similarity index 100% rename from strategy_templates/key_monitor_rule_tips.html rename to lib/strategy/templates/key_monitor_rule_tips.html diff --git a/strategy_templates/order_focus_v2.html b/lib/strategy/templates/order_focus_v2.html similarity index 100% rename from strategy_templates/order_focus_v2.html rename to lib/strategy/templates/order_focus_v2.html diff --git a/strategy_templates/order_monitor_rule_tips_binance.html b/lib/strategy/templates/order_monitor_rule_tips_binance.html similarity index 100% rename from strategy_templates/order_monitor_rule_tips_binance.html rename to lib/strategy/templates/order_monitor_rule_tips_binance.html diff --git a/strategy_templates/order_monitor_rule_tips_gate.html b/lib/strategy/templates/order_monitor_rule_tips_gate.html similarity index 100% rename from strategy_templates/order_monitor_rule_tips_gate.html rename to lib/strategy/templates/order_monitor_rule_tips_gate.html diff --git a/strategy_templates/order_monitor_rule_tips_gate_bot.html b/lib/strategy/templates/order_monitor_rule_tips_gate_bot.html similarity index 100% rename from strategy_templates/order_monitor_rule_tips_gate_bot.html rename to lib/strategy/templates/order_monitor_rule_tips_gate_bot.html diff --git a/strategy_templates/order_monitor_rule_tips_okx.html b/lib/strategy/templates/order_monitor_rule_tips_okx.html similarity index 100% rename from strategy_templates/order_monitor_rule_tips_okx.html rename to lib/strategy/templates/order_monitor_rule_tips_okx.html diff --git a/strategy_templates/order_plan_preview_bar.html b/lib/strategy/templates/order_plan_preview_bar.html similarity index 100% rename from strategy_templates/order_plan_preview_bar.html rename to lib/strategy/templates/order_plan_preview_bar.html diff --git a/strategy_templates/strategy_records_page.html b/lib/strategy/templates/strategy_records_page.html similarity index 100% rename from strategy_templates/strategy_records_page.html rename to lib/strategy/templates/strategy_records_page.html diff --git a/strategy_templates/strategy_roll.html b/lib/strategy/templates/strategy_roll.html similarity index 100% rename from strategy_templates/strategy_roll.html rename to lib/strategy/templates/strategy_roll.html diff --git a/strategy_templates/strategy_roll_docs.html b/lib/strategy/templates/strategy_roll_docs.html similarity index 100% rename from strategy_templates/strategy_roll_docs.html rename to lib/strategy/templates/strategy_roll_docs.html diff --git a/strategy_templates/strategy_roll_panel.html b/lib/strategy/templates/strategy_roll_panel.html similarity index 100% rename from strategy_templates/strategy_roll_panel.html rename to lib/strategy/templates/strategy_roll_panel.html diff --git a/strategy_templates/strategy_subnav.html b/lib/strategy/templates/strategy_subnav.html similarity index 100% rename from strategy_templates/strategy_subnav.html rename to lib/strategy/templates/strategy_subnav.html diff --git a/strategy_templates/strategy_trading_page.html b/lib/strategy/templates/strategy_trading_page.html similarity index 100% rename from strategy_templates/strategy_trading_page.html rename to lib/strategy/templates/strategy_trading_page.html diff --git a/strategy_templates/strategy_trend_disabled.html b/lib/strategy/templates/strategy_trend_disabled.html similarity index 100% rename from strategy_templates/strategy_trend_disabled.html rename to lib/strategy/templates/strategy_trend_disabled.html diff --git a/strategy_templates/strategy_trend_disabled_panel.html b/lib/strategy/templates/strategy_trend_disabled_panel.html similarity index 100% rename from strategy_templates/strategy_trend_disabled_panel.html rename to lib/strategy/templates/strategy_trend_disabled_panel.html diff --git a/strategy_templates/strategy_trend_panel.html b/lib/strategy/templates/strategy_trend_panel.html similarity index 100% rename from strategy_templates/strategy_trend_panel.html rename to lib/strategy/templates/strategy_trend_panel.html diff --git a/lib/trade/__init__.py b/lib/trade/__init__.py new file mode 100644 index 0000000..ab164b5 --- /dev/null +++ b/lib/trade/__init__.py @@ -0,0 +1 @@ +"""Shared library package.""" diff --git a/account_risk_lib.py b/lib/trade/account_risk_lib.py similarity index 96% rename from account_risk_lib.py rename to lib/trade/account_risk_lib.py index 3fe848a..0487f65 100644 --- a/account_risk_lib.py +++ b/lib/trade/account_risk_lib.py @@ -1,845 +1,845 @@ -"""账户冷静期 / 日冻结风控(四所实例共用)。""" -from __future__ import annotations - -import os -from datetime import datetime, timezone -from typing import Any, Callable, Optional - -STATUS_NORMAL = "normal" -STATUS_FREEZE_1H = "freeze_1h" -STATUS_FREEZE_4H = "freeze_4h" -STATUS_DAILY = "freeze_daily" -STATUS_FREEZE_POSITION = "freeze_position" - -STATUS_LABELS = { - STATUS_NORMAL: "正常", - STATUS_FREEZE_1H: "1h冻结", - STATUS_FREEZE_4H: "4h冻结", - STATUS_DAILY: "日冻结", - STATUS_FREEZE_POSITION: "仓位上限冻结", -} - -MOOD_ISSUE_OPTIONS = ( - "怕踏空", - "报复开仓", - "盈利飘了", - "拿不住单", - "扛单", - "重仓违规", -) - -# 仅以下来源计入「手动平仓」风控(用户主动点平仓/结束计划) -CLOSE_SOURCE_USER_INSTANCE = "user_instance" -CLOSE_SOURCE_USER_HUB = "user_hub" -CLOSE_SOURCE_USER_TREND_STOP = "user_trend_stop" - -USER_INITIATED_CLOSE_SOURCES = frozenset( - { - CLOSE_SOURCE_USER_INSTANCE, - CLOSE_SOURCE_USER_HUB, - CLOSE_SOURCE_USER_TREND_STOP, - } -) - - -def _env_bool(key: str, default: bool = True) -> bool: - raw = (os.getenv(key) or "").strip().lower() - if not raw: - return default - return raw in ("1", "true", "yes", "on") - - -def _env_hours(key: str, default: float) -> float: - try: - v = float(os.getenv(key, str(default))) - except (TypeError, ValueError): - v = default - return max(0.0, v) - - -def _app_tz(): - from zoneinfo import ZoneInfo - - name = (os.getenv("APP_TIMEZONE") or os.getenv("TZ") or "Asia/Shanghai").strip() - try: - return ZoneInfo(name) - except Exception: - return ZoneInfo("Asia/Shanghai") - - -def risk_control_enabled() -> bool: - return _env_bool("RISK_CONTROL_ENABLED", True) - - -def cooling_hours_manual() -> float: - return _env_hours("RISK_COOLING_HOURS_MANUAL", 4.0) - - -def cooling_hours_manual_journal() -> float: - return _env_hours("RISK_COOLING_HOURS_MANUAL_JOURNAL", 1.0) - - -def manual_close_daily_limit() -> int: - try: - return max(1, int(os.getenv("RISK_MANUAL_CLOSE_DAILY_LIMIT", "2"))) - except (TypeError, ValueError): - return 2 - - -def max_active_positions_from_env(default: int = 1) -> int: - try: - return max(1, int(os.getenv("MAX_ACTIVE_POSITIONS", str(default)))) - except (TypeError, ValueError): - return max(1, default) - - -def position_limit_reached( - conn, - *, - max_active_positions: Optional[int] = None, -) -> tuple[bool, int, int]: - """(已达上限, 计入上限的活跃数, 上限值)。""" - from strategy_trade_labels import count_position_limit_active_monitors - - mx = max(1, int(max_active_positions if max_active_positions is not None else max_active_positions_from_env())) - ac = count_position_limit_active_monitors(conn) - return ac >= mx, ac, mx - - -def mood_issues_daily_freeze_enabled() -> bool: - return _env_bool("RISK_MOOD_ISSUES_DAILY_FREEZE", True) - - -def ensure_account_risk_schema(conn) -> None: - conn.execute( - """CREATE TABLE IF NOT EXISTS account_risk_state ( - id INTEGER PRIMARY KEY CHECK (id = 1), - trading_day TEXT, - manual_close_count INTEGER DEFAULT 0, - cooloff_until_ms INTEGER, - cooloff_hours INTEGER, - daily_frozen INTEGER DEFAULT 0, - pending_journal_trade_id INTEGER, - last_close_at_ms INTEGER, - updated_at TEXT - )""" - ) - row = conn.execute("SELECT id FROM account_risk_state WHERE id=1").fetchone() - if not row: - conn.execute( - "INSERT INTO account_risk_state (id, trading_day, manual_close_count, daily_frozen) VALUES (1, '', 0, 0)" - ) - - -def _row_get(row, key, default=None): - if row is None: - return default - try: - return row[key] - except (KeyError, IndexError, TypeError): - return default - - -def _now_ms(now: Optional[datetime] = None) -> int: - dt = now or datetime.now() - if dt.tzinfo is None: - dt = dt.replace(tzinfo=_app_tz()) - return int(dt.timestamp() * 1000) - - -def _normalize_epoch_ms(ms: int, ref_now_ms: Optional[int] = None) -> int: - """修正旧版把北京时间 naive 当作 UTC 写入的 epoch 毫秒。""" - tz = _app_tz() - off = datetime.now(tz).utcoffset() - if not off: - return int(ms) - offset_ms = int(off.total_seconds() * 1000) - if offset_ms == 0: - return int(ms) - ref = int(ref_now_ms) if ref_now_ms is not None else _now_ms(datetime.now(tz)) - corrected = int(ms) - offset_ms - if abs(int(ms) - ref) <= abs(corrected - ref): - return int(ms) - return corrected - - -def _sanitize_last_close_ms(last_ms: int, now_ms: int) -> Optional[int]: - """平仓时刻须不晚于当前(允许 1 分钟时钟偏差);显著未来视为无效锚点。""" - slack_ms = 60 * 1000 - if last_ms > now_ms + slack_ms: - return None - return last_ms - - -def _cooloff_duration_ms(hours: float) -> int: - return int(max(0.0, float(hours)) * 3600 * 1000) - - -def _cooloff_hours_value(row) -> float: - return float(_row_get(row, "cooloff_hours") or cooling_hours_manual()) - - -def _resolved_cooloff_until_ms(row, now_ms: int) -> Optional[int]: - """冷静期结束 = last_close + cooloff_hours;无效/已过期锚点不再重启计时。""" - hours = _cooloff_hours_value(row) - journal_h = cooling_hours_manual_journal() - duration_ms = _cooloff_duration_ms(hours) - last_raw = _row_get(row, "last_close_at_ms") - stored_raw = _cooloff_until_ms(row) - - if last_raw is not None: - try: - last_ms = _sanitize_last_close_ms( - _normalize_epoch_ms(int(last_raw), now_ms), now_ms - ) - except (TypeError, ValueError): - last_ms = None - if last_ms is not None: - end_ms = last_ms + duration_ms - if end_ms > now_ms: - return end_ms - if hours <= journal_h + 1e-6: - return None - - if stored_raw is None: - return None - stored_ms = _normalize_epoch_ms(int(stored_raw), now_ms) - return stored_ms if stored_ms > now_ms else None - - -def _clear_inactive_cooloff( - conn, - *, - now: Optional[datetime] = None, -) -> None: - """冷静期已结束或锚点无效时清库,避免重启后误读旧冻结。""" - conn.execute( - """UPDATE account_risk_state SET - cooloff_until_ms=NULL, - cooloff_hours=NULL, - last_close_at_ms=NULL, - updated_at=? - WHERE id=1""", - ((now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"),), - ) - - -def _freeze_tier_from_remaining_ms(remaining_ms: int, hours: float) -> str: - journal_h = cooling_hours_manual_journal() - rh = remaining_ms / 3600000.0 - if rh <= journal_h + (5 / 60): - return STATUS_FREEZE_1H - return STATUS_FREEZE_4H - - -def _freeze_status_label(hours: float, status: str) -> str: - if status == STATUS_FREEZE_1H: - return STATUS_LABELS[STATUS_FREEZE_1H] - if status == STATUS_FREEZE_4H: - h = int(hours) if float(hours) == int(hours) else round(float(hours), 1) - if abs(float(hours) - 4.0) < 1e-6: - return STATUS_LABELS[STATUS_FREEZE_4H] - return f"{h}h冻结" - return STATUS_LABELS.get(status, STATUS_LABELS[STATUS_NORMAL]) - - -def _ms_to_local_str(ms: Optional[int], fmt_local: Callable[[int], str]) -> Optional[str]: - if ms is None: - return None - try: - return fmt_local(int(ms)) - except Exception: - return None - - -def _load_state(conn): - ensure_account_risk_schema(conn) - return conn.execute("SELECT * FROM account_risk_state WHERE id=1").fetchone() - - -def _sync_trading_day(conn, trading_day: str, now: Optional[datetime] = None) -> Any: - row = _load_state(conn) - td = (trading_day or "").strip() - stored = str(_row_get(row, "trading_day") or "").strip() - if stored != td: - now_ms = _now_ms(now) - cooloff_active = _resolved_cooloff_until_ms(row, now_ms) - conn.execute( - """UPDATE account_risk_state SET - trading_day=?, - manual_close_count=0, - daily_frozen=0, - cooloff_until_ms=?, - cooloff_hours=?, - last_close_at_ms=?, - pending_journal_trade_id=NULL, - updated_at=? - WHERE id=1""", - ( - td, - cooloff_active, - _row_get(row, "cooloff_hours") if cooloff_active else None, - _row_get(row, "last_close_at_ms") if cooloff_active else None, - (now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"), - ), - ) - row = _load_state(conn) - return row - - -def _set_cooloff( - conn, - *, - trading_day: str, - close_at_ms: int, - hours: float, - now: Optional[datetime] = None, -) -> None: - _sync_trading_day(conn, trading_day, now=now) - h = max(0.0, float(hours)) - until_ms = int(close_at_ms + h * 3600 * 1000) - conn.execute( - """UPDATE account_risk_state SET - cooloff_until_ms=?, - cooloff_hours=?, - last_close_at_ms=?, - updated_at=? - WHERE id=1""", - ( - until_ms, - int(h) if h == int(h) else int(round(h)), - int(close_at_ms), - (now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"), - ), - ) - - -def _set_cooloff_until( - conn, - *, - trading_day: str, - until_ms: int, - hours: float, - now: Optional[datetime] = None, -) -> None: - _sync_trading_day(conn, trading_day, now=now) - h = max(0.0, float(hours)) - conn.execute( - """UPDATE account_risk_state SET - cooloff_until_ms=?, - cooloff_hours=?, - updated_at=? - WHERE id=1""", - ( - int(until_ms), - int(h) if h == int(h) else int(round(h)), - (now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"), - ), - ) - - -def _ms_trading_day_label(ms: int) -> str: - dt = datetime.fromtimestamp(ms / 1000, tz=_app_tz()) - return dt.strftime("%Y-%m-%d") - - -def _parse_journal_close_ms(raw: Any) -> Optional[int]: - if raw is None: - return None - s = str(raw).strip() - if not s: - return None - for fmt in ("%Y-%m-%d %H:%M:%S", "%Y/%m/%d %H:%M:%S", "%Y-%m-%d %H:%M"): - try: - dt = datetime.strptime(s[:19] if len(s) > 16 else s, fmt) - return _now_ms(dt) - except ValueError: - continue - return None - - -def _latest_journaled_manual_close_ms(conn, trading_day: str) -> Optional[int]: - """当日最近一条已复盘的手动平仓时刻(journal 有说明)。""" - try: - rows = conn.execute( - """SELECT close_datetime FROM journal_entries - WHERE early_exit_trigger='手动平仓' - AND early_exit_note IS NOT NULL AND TRIM(early_exit_note) <> '' - ORDER BY close_datetime DESC""" - ).fetchall() - except Exception: - return None - td = (trading_day or "").strip() - best: Optional[int] = None - for row in rows: - ms = _parse_journal_close_ms(_row_get(row, "close_datetime")) - if ms is None: - continue - if td and _ms_trading_day_label(ms) != td: - continue - if best is None or ms > best: - best = ms - return best - - -def _journaled_manual_cooloff_expired( - conn, *, trading_day: str, now_ms: int, pending: Any -) -> bool: - """当日手动平仓已复盘且 1h 冷静期结束,且无待复盘的新平仓。""" - if pending is not None: - try: - if int(pending) != 0: - return False - except (TypeError, ValueError): - return False - close_ms = _latest_journaled_manual_close_ms(conn, trading_day) - if close_ms is None: - return False - journal_ms = _cooloff_duration_ms(cooling_hours_manual_journal()) - return close_ms + journal_ms <= now_ms - - -def _cooloff_until_ms(row) -> Optional[int]: - raw = _row_get(row, "cooloff_until_ms") - try: - return int(raw) if raw is not None else None - except (TypeError, ValueError): - return None - - -def _repair_stale_cooloff_row( - conn, - row, - *, - now_ms: int, - resolved_until_ms: Optional[int], - now: Optional[datetime] = None, -) -> None: - """脏数据读时写回:过期/无效则清库,否则对齐 until / last_close。""" - last_raw = _row_get(row, "last_close_at_ms") - stored_raw = _cooloff_until_ms(row) - if last_raw is None and stored_raw is None: - return - if resolved_until_ms is None: - if last_raw is not None or stored_raw is not None: - _clear_inactive_cooloff(conn, now=now) - return - dirty = False - new_last: Optional[int] = None - if last_raw is not None: - try: - norm = _normalize_epoch_ms(int(last_raw), now_ms) - sanitized = _sanitize_last_close_ms(norm, now_ms) - if sanitized is None: - dirty = True - else: - new_last = sanitized - if sanitized != int(last_raw): - dirty = True - except (TypeError, ValueError): - dirty = True - if stored_raw is not None: - stored_norm = _normalize_epoch_ms(int(stored_raw), now_ms) - if abs(stored_norm - int(resolved_until_ms)) > 60 * 1000: - dirty = True - if not dirty: - return - conn.execute( - """UPDATE account_risk_state SET - cooloff_until_ms=?, - cooloff_hours=?, - last_close_at_ms=?, - updated_at=? - WHERE id=1""", - ( - resolved_until_ms, - _row_get(row, "cooloff_hours"), - new_last, - (now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"), - ), - ) - - -def _journal_can_reduce_cooloff(row, pending, now_ms: int) -> bool: - if int(_row_get(row, "daily_frozen") or 0) == 1: - return False - if _resolved_cooloff_until_ms(row, now_ms) is None: - return False - journal_h = cooling_hours_manual_journal() - cooloff_h = float(_row_get(row, "cooloff_hours") or cooling_hours_manual()) - if cooloff_h <= journal_h + 1e-6: - return False - if pending is not None: - try: - if int(pending) != 0: - return True - except (TypeError, ValueError): - return True - return True - - -def _journal_cooloff_until_ms(row, now_ms: int, journal_hours: float) -> int: - journal_ms = int(max(0.0, float(journal_hours)) * 3600 * 1000) - last_close_ms = _row_get(row, "last_close_at_ms") - if last_close_ms: - try: - base_ms = _sanitize_last_close_ms( - _normalize_epoch_ms(int(last_close_ms), now_ms), now_ms - ) - except (TypeError, ValueError): - base_ms = None - if base_ms is None: - base_ms = now_ms - else: - base_ms = now_ms - until_from_close = base_ms + journal_ms - if until_from_close > now_ms: - return until_from_close - return now_ms + journal_ms - - -def _set_daily_frozen(conn, *, trading_day: str, now: Optional[datetime] = None) -> None: - _sync_trading_day(conn, trading_day, now=now) - conn.execute( - """UPDATE account_risk_state SET daily_frozen=1, updated_at=? WHERE id=1""", - ((now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"),), - ) - - -def parse_mood_issues(raw: Any) -> list[str]: - if raw is None: - return [] - if isinstance(raw, (list, tuple)): - parts = [str(x).strip() for x in raw if str(x).strip()] - else: - parts = [x.strip() for x in str(raw).split(",") if x.strip()] - return [p for p in parts if p in MOOD_ISSUE_OPTIONS] - - -def _record_one_user_initiated_close( - conn, - *, - source: str, - trade_record_id: Optional[int], - closed_at_ms: Optional[int], - trading_day: str, - now: Optional[datetime] = None, -) -> None: - row = _sync_trading_day(conn, trading_day, now=now) - count = int(_row_get(row, "manual_close_count") or 0) + 1 - close_ms = int(closed_at_ms) if closed_at_ms else _now_ms(now) - pending = int(trade_record_id) if trade_record_id else None - conn.execute( - """UPDATE account_risk_state SET - manual_close_count=?, - pending_journal_trade_id=?, - updated_at=? - WHERE id=1""", - (count, pending, (now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S")), - ) - if count >= manual_close_daily_limit(): - _set_daily_frozen(conn, trading_day=trading_day, now=now) - return - _set_cooloff( - conn, - trading_day=trading_day, - close_at_ms=close_ms, - hours=cooling_hours_manual(), - now=now, - ) - - -def on_user_initiated_close( - conn, - *, - source: str, - trade_record_id: Optional[int] = None, - closed_at_ms: Optional[int] = None, - trading_day: str, - now: Optional[datetime] = None, - count: int = 1, -) -> None: - """用户主动平仓/结束趋势计划:计入手动平仓次数与冷静期。""" - if not risk_control_enabled(): - return - src = (source or "").strip() - if src not in USER_INITIATED_CLOSE_SOURCES: - return - n = max(1, int(count or 1)) - for i in range(n): - _record_one_user_initiated_close( - conn, - source=src, - trade_record_id=trade_record_id if i == 0 else None, - closed_at_ms=closed_at_ms, - trading_day=trading_day, - now=now, - ) - row = _load_state(conn) - if int(_row_get(row, "daily_frozen") or 0) == 1: - break - - -def on_manual_close( - conn, - *, - trade_record_id: int, - closed_at_ms: Optional[int], - trading_day: str, - now: Optional[datetime] = None, -) -> None: - """兼容旧调用:等同实例页用户平仓。""" - on_user_initiated_close( - conn, - source=CLOSE_SOURCE_USER_INSTANCE, - trade_record_id=trade_record_id, - closed_at_ms=closed_at_ms, - trading_day=trading_day, - now=now, - count=1, - ) - - -def on_journal_saved( - conn, - *, - early_exit_trigger: str, - early_exit_note: str, - mood_issues_raw: Any, - trading_day: str, - now: Optional[datetime] = None, -) -> None: - if not risk_control_enabled(): - return - row = _sync_trading_day(conn, trading_day, now=now) - mood_list = parse_mood_issues(mood_issues_raw) - if mood_issues_daily_freeze_enabled() and mood_list: - _set_daily_frozen(conn, trading_day=trading_day, now=now) - conn.execute( - "UPDATE account_risk_state SET pending_journal_trade_id=NULL, updated_at=? WHERE id=1", - ((now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"),), - ) - return - pending = _row_get(row, "pending_journal_trade_id") - trigger = (early_exit_trigger or "").strip() - note = (early_exit_note or "").strip() - now_ms = _now_ms(now) - if ( - trigger == "手动平仓" - and note - and int(_row_get(row, "daily_frozen") or 0) != 1 - and _journal_can_reduce_cooloff(row, pending, now_ms) - ): - journal_h = cooling_hours_manual_journal() - until_ms = _journal_cooloff_until_ms(row, now_ms, journal_h) - _set_cooloff_until( - conn, - trading_day=trading_day, - until_ms=until_ms, - hours=journal_h, - now=now, - ) - anchor_ms = until_ms - int(journal_h * 3600 * 1000) - conn.execute( - """UPDATE account_risk_state SET - pending_journal_trade_id=NULL, - last_close_at_ms=?, - updated_at=? - WHERE id=1""", - (int(anchor_ms), (now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S")), - ) - return - - -def apply_manual_close_journal_cooloff( - conn, - *, - early_exit_note: str, - trading_day: str, - now: Optional[datetime] = None, -) -> None: - """核对修改或复盘:手动平仓 + 说明后尝试将 4h 冷静期降为 1h。""" - note = (early_exit_note or "").strip() - if not note: - return - on_journal_saved( - conn, - early_exit_trigger="手动平仓", - early_exit_note=note, - mood_issues_raw="", - trading_day=trading_day, - now=now, - ) - - -def _next_trading_day_reset_ms(now: datetime, reset_hour: int) -> int: - from datetime import timedelta - - h = max(0, min(23, int(reset_hour))) - candidate = now.replace(hour=h, minute=0, second=0, microsecond=0) - if now >= candidate: - candidate = candidate + timedelta(days=1) - return _now_ms(candidate) - - -def enrich_risk_status_countdown( - st: dict[str, Any], - *, - now: Optional[datetime] = None, - daily_reset_hour: int = 8, -) -> dict[str, Any]: - """补充 freeze_until_ms / freeze_remaining_sec,供前端倒计时展示。""" - if not st.get("enabled", True): - return st - dt = now or datetime.now() - now_ms = _now_ms(dt) - until_ms: Optional[int] = None - if st.get("daily_frozen"): - until_ms = _next_trading_day_reset_ms(dt, daily_reset_hour) - elif st.get("cooloff_until_ms"): - try: - until_ms = int(st["cooloff_until_ms"]) - except (TypeError, ValueError): - until_ms = None - if until_ms is not None and until_ms > now_ms: - st["freeze_until_ms"] = until_ms - st["freeze_remaining_sec"] = max(0, (until_ms - now_ms) // 1000) - else: - st["freeze_until_ms"] = None - st["freeze_remaining_sec"] = 0 - return st - - -def apply_position_limit_risk( - st: dict[str, Any], - active_count: int, - *, - max_active_positions: Optional[int] = None, -) -> dict[str, Any]: - """持仓达 env MAX_ACTIVE_POSITIONS 时叠加「仓位上限冻结」(时间冻结优先展示)。""" - out = dict(st or {}) - try: - mx = max(1, int(max_active_positions if max_active_positions is not None else max_active_positions_from_env())) - except (TypeError, ValueError): - mx = max_active_positions_from_env() - try: - ac = max(0, int(active_count)) - except (TypeError, ValueError): - ac = 0 - out["max_active_positions"] = mx - out["active_count"] = ac - if out.get("status") != STATUS_NORMAL: - return out - if ac >= mx: - out["status"] = STATUS_FREEZE_POSITION - out["status_label"] = STATUS_LABELS[STATUS_FREEZE_POSITION] - out["can_trade"] = False - out["can_roll"] = True - out["reason"] = f"已达最大持仓数({ac}/{mx}),新开仓已冻结,顺势加仓仍可用" - out["position_limit_frozen"] = True - out["freeze_until_ms"] = None - out["freeze_remaining_sec"] = 0 - else: - out["position_limit_frozen"] = False - out["can_roll"] = True - return out - - -def compute_account_risk_status( - conn, - *, - trading_day: str, - now: Optional[datetime] = None, - fmt_local_ms: Optional[Callable[[int], str]] = None, -) -> dict[str, Any]: - if not risk_control_enabled(): - return { - "enabled": False, - "status": STATUS_NORMAL, - "status_label": STATUS_LABELS[STATUS_NORMAL], - "can_trade": True, - "reason": "", - "cooloff_until_ms": None, - "cooloff_until": None, - "manual_close_count": 0, - "daily_frozen": False, - } - row = _sync_trading_day(conn, trading_day, now=now) - now_ms = _now_ms(now) - daily_frozen = int(_row_get(row, "daily_frozen") or 0) == 1 - pending = _row_get(row, "pending_journal_trade_id") - cooloff_until_ms = _resolved_cooloff_until_ms(row, now_ms) - if ( - not daily_frozen - and cooloff_until_ms is not None - and _journaled_manual_cooloff_expired( - conn, trading_day=trading_day, now_ms=now_ms, pending=pending - ) - ): - cooloff_until_ms = None - if not daily_frozen: - _repair_stale_cooloff_row( - conn, row, now_ms=now_ms, resolved_until_ms=cooloff_until_ms, now=now - ) - row = _load_state(conn) - cooloff_until_ms = _resolved_cooloff_until_ms(row, now_ms) - manual_close_count = int(_row_get(row, "manual_close_count") or 0) - - status = STATUS_NORMAL - reason = "" - if daily_frozen: - status = STATUS_DAILY - reason = f"账户今日已冻结(手动平仓 {manual_close_count} 次或复盘情绪标签)" - elif cooloff_until_ms is not None: - remaining_ms = cooloff_until_ms - now_ms - hours = _cooloff_hours_value(row) - status = _freeze_tier_from_remaining_ms(remaining_ms, hours) - status_label = _freeze_status_label(hours, status) - until_str = _ms_to_local_str(cooloff_until_ms, fmt_local_ms) if fmt_local_ms else None - label = status_label - reason = f"账户{label}中" - if until_str: - reason += f",至 {until_str}" - - can_trade = status == STATUS_NORMAL - freeze_remaining_sec = ( - max(0, (cooloff_until_ms - now_ms) // 1000) if cooloff_until_ms is not None else 0 - ) - return { - "enabled": True, - "status": status, - "status_label": _freeze_status_label(_cooloff_hours_value(row), status) - if status in (STATUS_FREEZE_1H, STATUS_FREEZE_4H) - else STATUS_LABELS[status], - "can_trade": can_trade, - "reason": reason, - "cooloff_until_ms": cooloff_until_ms, - "cooloff_until": _ms_to_local_str(cooloff_until_ms, fmt_local_ms) - if fmt_local_ms and cooloff_until_ms - else None, - "manual_close_count": manual_close_count, - "daily_frozen": daily_frozen, - "pending_journal_trade_id": pending, - "freeze_remaining_sec": freeze_remaining_sec if not can_trade else 0, - } - - -def account_risk_blocks_trading( - conn, - *, - trading_day: str, - now: Optional[datetime] = None, - fmt_local_ms: Optional[Callable[[int], str]] = None, -) -> tuple[bool, str]: - """返回 (允许交易, 拒绝原因)。""" - st = compute_account_risk_status( - conn, trading_day=trading_day, now=now, fmt_local_ms=fmt_local_ms - ) - if st.get("can_trade"): - return True, "" - return False, str(st.get("reason") or STATUS_LABELS.get(st.get("status"), "账户冻结")) - - -def insert_trade_record_id(conn) -> int: - row = conn.execute("SELECT last_insert_rowid()").fetchone() - return int(row[0] if row else 0) +"""账户冷静期 / 日冻结风控(四所实例共用)。""" +from __future__ import annotations + +import os +from datetime import datetime, timezone +from typing import Any, Callable, Optional + +STATUS_NORMAL = "normal" +STATUS_FREEZE_1H = "freeze_1h" +STATUS_FREEZE_4H = "freeze_4h" +STATUS_DAILY = "freeze_daily" +STATUS_FREEZE_POSITION = "freeze_position" + +STATUS_LABELS = { + STATUS_NORMAL: "正常", + STATUS_FREEZE_1H: "1h冻结", + STATUS_FREEZE_4H: "4h冻结", + STATUS_DAILY: "日冻结", + STATUS_FREEZE_POSITION: "仓位上限冻结", +} + +MOOD_ISSUE_OPTIONS = ( + "怕踏空", + "报复开仓", + "盈利飘了", + "拿不住单", + "扛单", + "重仓违规", +) + +# 仅以下来源计入「手动平仓」风控(用户主动点平仓/结束计划) +CLOSE_SOURCE_USER_INSTANCE = "user_instance" +CLOSE_SOURCE_USER_HUB = "user_hub" +CLOSE_SOURCE_USER_TREND_STOP = "user_trend_stop" + +USER_INITIATED_CLOSE_SOURCES = frozenset( + { + CLOSE_SOURCE_USER_INSTANCE, + CLOSE_SOURCE_USER_HUB, + CLOSE_SOURCE_USER_TREND_STOP, + } +) + + +def _env_bool(key: str, default: bool = True) -> bool: + raw = (os.getenv(key) or "").strip().lower() + if not raw: + return default + return raw in ("1", "true", "yes", "on") + + +def _env_hours(key: str, default: float) -> float: + try: + v = float(os.getenv(key, str(default))) + except (TypeError, ValueError): + v = default + return max(0.0, v) + + +def _app_tz(): + from zoneinfo import ZoneInfo + + name = (os.getenv("APP_TIMEZONE") or os.getenv("TZ") or "Asia/Shanghai").strip() + try: + return ZoneInfo(name) + except Exception: + return ZoneInfo("Asia/Shanghai") + + +def risk_control_enabled() -> bool: + return _env_bool("RISK_CONTROL_ENABLED", True) + + +def cooling_hours_manual() -> float: + return _env_hours("RISK_COOLING_HOURS_MANUAL", 4.0) + + +def cooling_hours_manual_journal() -> float: + return _env_hours("RISK_COOLING_HOURS_MANUAL_JOURNAL", 1.0) + + +def manual_close_daily_limit() -> int: + try: + return max(1, int(os.getenv("RISK_MANUAL_CLOSE_DAILY_LIMIT", "2"))) + except (TypeError, ValueError): + return 2 + + +def max_active_positions_from_env(default: int = 1) -> int: + try: + return max(1, int(os.getenv("MAX_ACTIVE_POSITIONS", str(default)))) + except (TypeError, ValueError): + return max(1, default) + + +def position_limit_reached( + conn, + *, + max_active_positions: Optional[int] = None, +) -> tuple[bool, int, int]: + """(已达上限, 计入上限的活跃数, 上限值)。""" + from lib.strategy.strategy_trade_labels import count_position_limit_active_monitors + + mx = max(1, int(max_active_positions if max_active_positions is not None else max_active_positions_from_env())) + ac = count_position_limit_active_monitors(conn) + return ac >= mx, ac, mx + + +def mood_issues_daily_freeze_enabled() -> bool: + return _env_bool("RISK_MOOD_ISSUES_DAILY_FREEZE", True) + + +def ensure_account_risk_schema(conn) -> None: + conn.execute( + """CREATE TABLE IF NOT EXISTS account_risk_state ( + id INTEGER PRIMARY KEY CHECK (id = 1), + trading_day TEXT, + manual_close_count INTEGER DEFAULT 0, + cooloff_until_ms INTEGER, + cooloff_hours INTEGER, + daily_frozen INTEGER DEFAULT 0, + pending_journal_trade_id INTEGER, + last_close_at_ms INTEGER, + updated_at TEXT + )""" + ) + row = conn.execute("SELECT id FROM account_risk_state WHERE id=1").fetchone() + if not row: + conn.execute( + "INSERT INTO account_risk_state (id, trading_day, manual_close_count, daily_frozen) VALUES (1, '', 0, 0)" + ) + + +def _row_get(row, key, default=None): + if row is None: + return default + try: + return row[key] + except (KeyError, IndexError, TypeError): + return default + + +def _now_ms(now: Optional[datetime] = None) -> int: + dt = now or datetime.now() + if dt.tzinfo is None: + dt = dt.replace(tzinfo=_app_tz()) + return int(dt.timestamp() * 1000) + + +def _normalize_epoch_ms(ms: int, ref_now_ms: Optional[int] = None) -> int: + """修正旧版把北京时间 naive 当作 UTC 写入的 epoch 毫秒。""" + tz = _app_tz() + off = datetime.now(tz).utcoffset() + if not off: + return int(ms) + offset_ms = int(off.total_seconds() * 1000) + if offset_ms == 0: + return int(ms) + ref = int(ref_now_ms) if ref_now_ms is not None else _now_ms(datetime.now(tz)) + corrected = int(ms) - offset_ms + if abs(int(ms) - ref) <= abs(corrected - ref): + return int(ms) + return corrected + + +def _sanitize_last_close_ms(last_ms: int, now_ms: int) -> Optional[int]: + """平仓时刻须不晚于当前(允许 1 分钟时钟偏差);显著未来视为无效锚点。""" + slack_ms = 60 * 1000 + if last_ms > now_ms + slack_ms: + return None + return last_ms + + +def _cooloff_duration_ms(hours: float) -> int: + return int(max(0.0, float(hours)) * 3600 * 1000) + + +def _cooloff_hours_value(row) -> float: + return float(_row_get(row, "cooloff_hours") or cooling_hours_manual()) + + +def _resolved_cooloff_until_ms(row, now_ms: int) -> Optional[int]: + """冷静期结束 = last_close + cooloff_hours;无效/已过期锚点不再重启计时。""" + hours = _cooloff_hours_value(row) + journal_h = cooling_hours_manual_journal() + duration_ms = _cooloff_duration_ms(hours) + last_raw = _row_get(row, "last_close_at_ms") + stored_raw = _cooloff_until_ms(row) + + if last_raw is not None: + try: + last_ms = _sanitize_last_close_ms( + _normalize_epoch_ms(int(last_raw), now_ms), now_ms + ) + except (TypeError, ValueError): + last_ms = None + if last_ms is not None: + end_ms = last_ms + duration_ms + if end_ms > now_ms: + return end_ms + if hours <= journal_h + 1e-6: + return None + + if stored_raw is None: + return None + stored_ms = _normalize_epoch_ms(int(stored_raw), now_ms) + return stored_ms if stored_ms > now_ms else None + + +def _clear_inactive_cooloff( + conn, + *, + now: Optional[datetime] = None, +) -> None: + """冷静期已结束或锚点无效时清库,避免重启后误读旧冻结。""" + conn.execute( + """UPDATE account_risk_state SET + cooloff_until_ms=NULL, + cooloff_hours=NULL, + last_close_at_ms=NULL, + updated_at=? + WHERE id=1""", + ((now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"),), + ) + + +def _freeze_tier_from_remaining_ms(remaining_ms: int, hours: float) -> str: + journal_h = cooling_hours_manual_journal() + rh = remaining_ms / 3600000.0 + if rh <= journal_h + (5 / 60): + return STATUS_FREEZE_1H + return STATUS_FREEZE_4H + + +def _freeze_status_label(hours: float, status: str) -> str: + if status == STATUS_FREEZE_1H: + return STATUS_LABELS[STATUS_FREEZE_1H] + if status == STATUS_FREEZE_4H: + h = int(hours) if float(hours) == int(hours) else round(float(hours), 1) + if abs(float(hours) - 4.0) < 1e-6: + return STATUS_LABELS[STATUS_FREEZE_4H] + return f"{h}h冻结" + return STATUS_LABELS.get(status, STATUS_LABELS[STATUS_NORMAL]) + + +def _ms_to_local_str(ms: Optional[int], fmt_local: Callable[[int], str]) -> Optional[str]: + if ms is None: + return None + try: + return fmt_local(int(ms)) + except Exception: + return None + + +def _load_state(conn): + ensure_account_risk_schema(conn) + return conn.execute("SELECT * FROM account_risk_state WHERE id=1").fetchone() + + +def _sync_trading_day(conn, trading_day: str, now: Optional[datetime] = None) -> Any: + row = _load_state(conn) + td = (trading_day or "").strip() + stored = str(_row_get(row, "trading_day") or "").strip() + if stored != td: + now_ms = _now_ms(now) + cooloff_active = _resolved_cooloff_until_ms(row, now_ms) + conn.execute( + """UPDATE account_risk_state SET + trading_day=?, + manual_close_count=0, + daily_frozen=0, + cooloff_until_ms=?, + cooloff_hours=?, + last_close_at_ms=?, + pending_journal_trade_id=NULL, + updated_at=? + WHERE id=1""", + ( + td, + cooloff_active, + _row_get(row, "cooloff_hours") if cooloff_active else None, + _row_get(row, "last_close_at_ms") if cooloff_active else None, + (now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"), + ), + ) + row = _load_state(conn) + return row + + +def _set_cooloff( + conn, + *, + trading_day: str, + close_at_ms: int, + hours: float, + now: Optional[datetime] = None, +) -> None: + _sync_trading_day(conn, trading_day, now=now) + h = max(0.0, float(hours)) + until_ms = int(close_at_ms + h * 3600 * 1000) + conn.execute( + """UPDATE account_risk_state SET + cooloff_until_ms=?, + cooloff_hours=?, + last_close_at_ms=?, + updated_at=? + WHERE id=1""", + ( + until_ms, + int(h) if h == int(h) else int(round(h)), + int(close_at_ms), + (now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"), + ), + ) + + +def _set_cooloff_until( + conn, + *, + trading_day: str, + until_ms: int, + hours: float, + now: Optional[datetime] = None, +) -> None: + _sync_trading_day(conn, trading_day, now=now) + h = max(0.0, float(hours)) + conn.execute( + """UPDATE account_risk_state SET + cooloff_until_ms=?, + cooloff_hours=?, + updated_at=? + WHERE id=1""", + ( + int(until_ms), + int(h) if h == int(h) else int(round(h)), + (now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"), + ), + ) + + +def _ms_trading_day_label(ms: int) -> str: + dt = datetime.fromtimestamp(ms / 1000, tz=_app_tz()) + return dt.strftime("%Y-%m-%d") + + +def _parse_journal_close_ms(raw: Any) -> Optional[int]: + if raw is None: + return None + s = str(raw).strip() + if not s: + return None + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y/%m/%d %H:%M:%S", "%Y-%m-%d %H:%M"): + try: + dt = datetime.strptime(s[:19] if len(s) > 16 else s, fmt) + return _now_ms(dt) + except ValueError: + continue + return None + + +def _latest_journaled_manual_close_ms(conn, trading_day: str) -> Optional[int]: + """当日最近一条已复盘的手动平仓时刻(journal 有说明)。""" + try: + rows = conn.execute( + """SELECT close_datetime FROM journal_entries + WHERE early_exit_trigger='手动平仓' + AND early_exit_note IS NOT NULL AND TRIM(early_exit_note) <> '' + ORDER BY close_datetime DESC""" + ).fetchall() + except Exception: + return None + td = (trading_day or "").strip() + best: Optional[int] = None + for row in rows: + ms = _parse_journal_close_ms(_row_get(row, "close_datetime")) + if ms is None: + continue + if td and _ms_trading_day_label(ms) != td: + continue + if best is None or ms > best: + best = ms + return best + + +def _journaled_manual_cooloff_expired( + conn, *, trading_day: str, now_ms: int, pending: Any +) -> bool: + """当日手动平仓已复盘且 1h 冷静期结束,且无待复盘的新平仓。""" + if pending is not None: + try: + if int(pending) != 0: + return False + except (TypeError, ValueError): + return False + close_ms = _latest_journaled_manual_close_ms(conn, trading_day) + if close_ms is None: + return False + journal_ms = _cooloff_duration_ms(cooling_hours_manual_journal()) + return close_ms + journal_ms <= now_ms + + +def _cooloff_until_ms(row) -> Optional[int]: + raw = _row_get(row, "cooloff_until_ms") + try: + return int(raw) if raw is not None else None + except (TypeError, ValueError): + return None + + +def _repair_stale_cooloff_row( + conn, + row, + *, + now_ms: int, + resolved_until_ms: Optional[int], + now: Optional[datetime] = None, +) -> None: + """脏数据读时写回:过期/无效则清库,否则对齐 until / last_close。""" + last_raw = _row_get(row, "last_close_at_ms") + stored_raw = _cooloff_until_ms(row) + if last_raw is None and stored_raw is None: + return + if resolved_until_ms is None: + if last_raw is not None or stored_raw is not None: + _clear_inactive_cooloff(conn, now=now) + return + dirty = False + new_last: Optional[int] = None + if last_raw is not None: + try: + norm = _normalize_epoch_ms(int(last_raw), now_ms) + sanitized = _sanitize_last_close_ms(norm, now_ms) + if sanitized is None: + dirty = True + else: + new_last = sanitized + if sanitized != int(last_raw): + dirty = True + except (TypeError, ValueError): + dirty = True + if stored_raw is not None: + stored_norm = _normalize_epoch_ms(int(stored_raw), now_ms) + if abs(stored_norm - int(resolved_until_ms)) > 60 * 1000: + dirty = True + if not dirty: + return + conn.execute( + """UPDATE account_risk_state SET + cooloff_until_ms=?, + cooloff_hours=?, + last_close_at_ms=?, + updated_at=? + WHERE id=1""", + ( + resolved_until_ms, + _row_get(row, "cooloff_hours"), + new_last, + (now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"), + ), + ) + + +def _journal_can_reduce_cooloff(row, pending, now_ms: int) -> bool: + if int(_row_get(row, "daily_frozen") or 0) == 1: + return False + if _resolved_cooloff_until_ms(row, now_ms) is None: + return False + journal_h = cooling_hours_manual_journal() + cooloff_h = float(_row_get(row, "cooloff_hours") or cooling_hours_manual()) + if cooloff_h <= journal_h + 1e-6: + return False + if pending is not None: + try: + if int(pending) != 0: + return True + except (TypeError, ValueError): + return True + return True + + +def _journal_cooloff_until_ms(row, now_ms: int, journal_hours: float) -> int: + journal_ms = int(max(0.0, float(journal_hours)) * 3600 * 1000) + last_close_ms = _row_get(row, "last_close_at_ms") + if last_close_ms: + try: + base_ms = _sanitize_last_close_ms( + _normalize_epoch_ms(int(last_close_ms), now_ms), now_ms + ) + except (TypeError, ValueError): + base_ms = None + if base_ms is None: + base_ms = now_ms + else: + base_ms = now_ms + until_from_close = base_ms + journal_ms + if until_from_close > now_ms: + return until_from_close + return now_ms + journal_ms + + +def _set_daily_frozen(conn, *, trading_day: str, now: Optional[datetime] = None) -> None: + _sync_trading_day(conn, trading_day, now=now) + conn.execute( + """UPDATE account_risk_state SET daily_frozen=1, updated_at=? WHERE id=1""", + ((now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"),), + ) + + +def parse_mood_issues(raw: Any) -> list[str]: + if raw is None: + return [] + if isinstance(raw, (list, tuple)): + parts = [str(x).strip() for x in raw if str(x).strip()] + else: + parts = [x.strip() for x in str(raw).split(",") if x.strip()] + return [p for p in parts if p in MOOD_ISSUE_OPTIONS] + + +def _record_one_user_initiated_close( + conn, + *, + source: str, + trade_record_id: Optional[int], + closed_at_ms: Optional[int], + trading_day: str, + now: Optional[datetime] = None, +) -> None: + row = _sync_trading_day(conn, trading_day, now=now) + count = int(_row_get(row, "manual_close_count") or 0) + 1 + close_ms = int(closed_at_ms) if closed_at_ms else _now_ms(now) + pending = int(trade_record_id) if trade_record_id else None + conn.execute( + """UPDATE account_risk_state SET + manual_close_count=?, + pending_journal_trade_id=?, + updated_at=? + WHERE id=1""", + (count, pending, (now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S")), + ) + if count >= manual_close_daily_limit(): + _set_daily_frozen(conn, trading_day=trading_day, now=now) + return + _set_cooloff( + conn, + trading_day=trading_day, + close_at_ms=close_ms, + hours=cooling_hours_manual(), + now=now, + ) + + +def on_user_initiated_close( + conn, + *, + source: str, + trade_record_id: Optional[int] = None, + closed_at_ms: Optional[int] = None, + trading_day: str, + now: Optional[datetime] = None, + count: int = 1, +) -> None: + """用户主动平仓/结束趋势计划:计入手动平仓次数与冷静期。""" + if not risk_control_enabled(): + return + src = (source or "").strip() + if src not in USER_INITIATED_CLOSE_SOURCES: + return + n = max(1, int(count or 1)) + for i in range(n): + _record_one_user_initiated_close( + conn, + source=src, + trade_record_id=trade_record_id if i == 0 else None, + closed_at_ms=closed_at_ms, + trading_day=trading_day, + now=now, + ) + row = _load_state(conn) + if int(_row_get(row, "daily_frozen") or 0) == 1: + break + + +def on_manual_close( + conn, + *, + trade_record_id: int, + closed_at_ms: Optional[int], + trading_day: str, + now: Optional[datetime] = None, +) -> None: + """兼容旧调用:等同实例页用户平仓。""" + on_user_initiated_close( + conn, + source=CLOSE_SOURCE_USER_INSTANCE, + trade_record_id=trade_record_id, + closed_at_ms=closed_at_ms, + trading_day=trading_day, + now=now, + count=1, + ) + + +def on_journal_saved( + conn, + *, + early_exit_trigger: str, + early_exit_note: str, + mood_issues_raw: Any, + trading_day: str, + now: Optional[datetime] = None, +) -> None: + if not risk_control_enabled(): + return + row = _sync_trading_day(conn, trading_day, now=now) + mood_list = parse_mood_issues(mood_issues_raw) + if mood_issues_daily_freeze_enabled() and mood_list: + _set_daily_frozen(conn, trading_day=trading_day, now=now) + conn.execute( + "UPDATE account_risk_state SET pending_journal_trade_id=NULL, updated_at=? WHERE id=1", + ((now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S"),), + ) + return + pending = _row_get(row, "pending_journal_trade_id") + trigger = (early_exit_trigger or "").strip() + note = (early_exit_note or "").strip() + now_ms = _now_ms(now) + if ( + trigger == "手动平仓" + and note + and int(_row_get(row, "daily_frozen") or 0) != 1 + and _journal_can_reduce_cooloff(row, pending, now_ms) + ): + journal_h = cooling_hours_manual_journal() + until_ms = _journal_cooloff_until_ms(row, now_ms, journal_h) + _set_cooloff_until( + conn, + trading_day=trading_day, + until_ms=until_ms, + hours=journal_h, + now=now, + ) + anchor_ms = until_ms - int(journal_h * 3600 * 1000) + conn.execute( + """UPDATE account_risk_state SET + pending_journal_trade_id=NULL, + last_close_at_ms=?, + updated_at=? + WHERE id=1""", + (int(anchor_ms), (now or datetime.now()).strftime("%Y-%m-%d %H:%M:%S")), + ) + return + + +def apply_manual_close_journal_cooloff( + conn, + *, + early_exit_note: str, + trading_day: str, + now: Optional[datetime] = None, +) -> None: + """核对修改或复盘:手动平仓 + 说明后尝试将 4h 冷静期降为 1h。""" + note = (early_exit_note or "").strip() + if not note: + return + on_journal_saved( + conn, + early_exit_trigger="手动平仓", + early_exit_note=note, + mood_issues_raw="", + trading_day=trading_day, + now=now, + ) + + +def _next_trading_day_reset_ms(now: datetime, reset_hour: int) -> int: + from datetime import timedelta + + h = max(0, min(23, int(reset_hour))) + candidate = now.replace(hour=h, minute=0, second=0, microsecond=0) + if now >= candidate: + candidate = candidate + timedelta(days=1) + return _now_ms(candidate) + + +def enrich_risk_status_countdown( + st: dict[str, Any], + *, + now: Optional[datetime] = None, + daily_reset_hour: int = 8, +) -> dict[str, Any]: + """补充 freeze_until_ms / freeze_remaining_sec,供前端倒计时展示。""" + if not st.get("enabled", True): + return st + dt = now or datetime.now() + now_ms = _now_ms(dt) + until_ms: Optional[int] = None + if st.get("daily_frozen"): + until_ms = _next_trading_day_reset_ms(dt, daily_reset_hour) + elif st.get("cooloff_until_ms"): + try: + until_ms = int(st["cooloff_until_ms"]) + except (TypeError, ValueError): + until_ms = None + if until_ms is not None and until_ms > now_ms: + st["freeze_until_ms"] = until_ms + st["freeze_remaining_sec"] = max(0, (until_ms - now_ms) // 1000) + else: + st["freeze_until_ms"] = None + st["freeze_remaining_sec"] = 0 + return st + + +def apply_position_limit_risk( + st: dict[str, Any], + active_count: int, + *, + max_active_positions: Optional[int] = None, +) -> dict[str, Any]: + """持仓达 env MAX_ACTIVE_POSITIONS 时叠加「仓位上限冻结」(时间冻结优先展示)。""" + out = dict(st or {}) + try: + mx = max(1, int(max_active_positions if max_active_positions is not None else max_active_positions_from_env())) + except (TypeError, ValueError): + mx = max_active_positions_from_env() + try: + ac = max(0, int(active_count)) + except (TypeError, ValueError): + ac = 0 + out["max_active_positions"] = mx + out["active_count"] = ac + if out.get("status") != STATUS_NORMAL: + return out + if ac >= mx: + out["status"] = STATUS_FREEZE_POSITION + out["status_label"] = STATUS_LABELS[STATUS_FREEZE_POSITION] + out["can_trade"] = False + out["can_roll"] = True + out["reason"] = f"已达最大持仓数({ac}/{mx}),新开仓已冻结,顺势加仓仍可用" + out["position_limit_frozen"] = True + out["freeze_until_ms"] = None + out["freeze_remaining_sec"] = 0 + else: + out["position_limit_frozen"] = False + out["can_roll"] = True + return out + + +def compute_account_risk_status( + conn, + *, + trading_day: str, + now: Optional[datetime] = None, + fmt_local_ms: Optional[Callable[[int], str]] = None, +) -> dict[str, Any]: + if not risk_control_enabled(): + return { + "enabled": False, + "status": STATUS_NORMAL, + "status_label": STATUS_LABELS[STATUS_NORMAL], + "can_trade": True, + "reason": "", + "cooloff_until_ms": None, + "cooloff_until": None, + "manual_close_count": 0, + "daily_frozen": False, + } + row = _sync_trading_day(conn, trading_day, now=now) + now_ms = _now_ms(now) + daily_frozen = int(_row_get(row, "daily_frozen") or 0) == 1 + pending = _row_get(row, "pending_journal_trade_id") + cooloff_until_ms = _resolved_cooloff_until_ms(row, now_ms) + if ( + not daily_frozen + and cooloff_until_ms is not None + and _journaled_manual_cooloff_expired( + conn, trading_day=trading_day, now_ms=now_ms, pending=pending + ) + ): + cooloff_until_ms = None + if not daily_frozen: + _repair_stale_cooloff_row( + conn, row, now_ms=now_ms, resolved_until_ms=cooloff_until_ms, now=now + ) + row = _load_state(conn) + cooloff_until_ms = _resolved_cooloff_until_ms(row, now_ms) + manual_close_count = int(_row_get(row, "manual_close_count") or 0) + + status = STATUS_NORMAL + reason = "" + if daily_frozen: + status = STATUS_DAILY + reason = f"账户今日已冻结(手动平仓 {manual_close_count} 次或复盘情绪标签)" + elif cooloff_until_ms is not None: + remaining_ms = cooloff_until_ms - now_ms + hours = _cooloff_hours_value(row) + status = _freeze_tier_from_remaining_ms(remaining_ms, hours) + status_label = _freeze_status_label(hours, status) + until_str = _ms_to_local_str(cooloff_until_ms, fmt_local_ms) if fmt_local_ms else None + label = status_label + reason = f"账户{label}中" + if until_str: + reason += f",至 {until_str}" + + can_trade = status == STATUS_NORMAL + freeze_remaining_sec = ( + max(0, (cooloff_until_ms - now_ms) // 1000) if cooloff_until_ms is not None else 0 + ) + return { + "enabled": True, + "status": status, + "status_label": _freeze_status_label(_cooloff_hours_value(row), status) + if status in (STATUS_FREEZE_1H, STATUS_FREEZE_4H) + else STATUS_LABELS[status], + "can_trade": can_trade, + "reason": reason, + "cooloff_until_ms": cooloff_until_ms, + "cooloff_until": _ms_to_local_str(cooloff_until_ms, fmt_local_ms) + if fmt_local_ms and cooloff_until_ms + else None, + "manual_close_count": manual_close_count, + "daily_frozen": daily_frozen, + "pending_journal_trade_id": pending, + "freeze_remaining_sec": freeze_remaining_sec if not can_trade else 0, + } + + +def account_risk_blocks_trading( + conn, + *, + trading_day: str, + now: Optional[datetime] = None, + fmt_local_ms: Optional[Callable[[int], str]] = None, +) -> tuple[bool, str]: + """返回 (允许交易, 拒绝原因)。""" + st = compute_account_risk_status( + conn, trading_day=trading_day, now=now, fmt_local_ms=fmt_local_ms + ) + if st.get("can_trade"): + return True, "" + return False, str(st.get("reason") or STATUS_LABELS.get(st.get("status"), "账户冻结")) + + +def insert_trade_record_id(conn) -> int: + row = conn.execute("SELECT last_insert_rowid()").fetchone() + return int(row[0] if row else 0) diff --git a/daily_open_limit_lib.py b/lib/trade/daily_open_limit_lib.py similarity index 100% rename from daily_open_limit_lib.py rename to lib/trade/daily_open_limit_lib.py diff --git a/manual_sltp_lib.py b/lib/trade/manual_sltp_lib.py similarity index 100% rename from manual_sltp_lib.py rename to lib/trade/manual_sltp_lib.py diff --git a/order_monitor_display_lib.py b/lib/trade/order_monitor_display_lib.py similarity index 100% rename from order_monitor_display_lib.py rename to lib/trade/order_monitor_display_lib.py diff --git a/position_sizing_lib.py b/lib/trade/position_sizing_lib.py similarity index 100% rename from position_sizing_lib.py rename to lib/trade/position_sizing_lib.py diff --git a/time_close_lib.py b/lib/trade/time_close_lib.py similarity index 100% rename from time_close_lib.py rename to lib/trade/time_close_lib.py diff --git a/trade_exchange_stats_lib.py b/lib/trade/trade_exchange_stats_lib.py similarity index 100% rename from trade_exchange_stats_lib.py rename to lib/trade/trade_exchange_stats_lib.py diff --git a/trade_result_lib.py b/lib/trade/trade_result_lib.py similarity index 100% rename from trade_result_lib.py rename to lib/trade/trade_result_lib.py diff --git a/trade_stats_calendar_lib.py b/lib/trade/trade_stats_calendar_lib.py similarity index 100% rename from trade_stats_calendar_lib.py rename to lib/trade/trade_stats_calendar_lib.py diff --git a/manual_trading_hub/agent.py b/manual_trading_hub/agent.py index 9cf96d3..1a2b8e2 100644 --- a/manual_trading_hub/agent.py +++ b/manual_trading_hub/agent.py @@ -1,908 +1,908 @@ -""" -子账户极轻代理:GET /status、挂单/条件单查询与撤销、POST /emergency/close-all、POST /emergency/close-position,仅监听 127.0.0.1。 - -与仓库内四个策略/监控目录一一对应时,典型用法(各目录自己的 .env 里已有密钥;子代理用环境变量 PORT,勿与 Flask 的 APP_PORT 相同): - EXCHANGE=binance → crypto_monitor_binance(BINANCE_*) - EXCHANGE=okx → crypto_monitor_okx(OKX_*) - EXCHANGE=gate → crypto_monitor_gate / crypto_monitor_gate_bot(GATE_*) - -环境变量: - EXCHANGE binance(默认)| okx | gate - PORT 默认 15200(与 crypto_monitor_* 的 Flask APP_PORT 错开;中控默认聚合 15200–15203) - HOST 默认 127.0.0.1 - CONTROL_TOKEN 可选;请求头 X-Control-Token - -Binance:BINANCE_API_KEY / BINANCE_API_SECRET;余额为 **U 本位永续合约账户** USDT(与 `crypto_monitor_binance` 的合约口径一致,非现货钱包);BINANCE_POSITION_MODE;BINANCE_MARGIN_MODE -OKX:OKX_API_KEY / OKX_API_SECRET / OKX_API_PASSPHRASE;OKX_TD_MODE;OKX_POS_MODE -Gate:GATE_API_KEY / GATE_API_SECRET;GATE_TD_MODE;GATE_POS_MODE - -代理与主项目一致时可设:BINANCE_SOCKS_PROXY / OKX_SOCKS_PROXY / GATE_SOCKS_PROXY(或 HTTP(S)_PROXY)。 -""" -from __future__ import annotations - -import math -import os -import sys -import time -from pathlib import Path -from typing import Any - -_REPO_ROOT = Path(__file__).resolve().parents[1] -if str(_REPO_ROOT) not in sys.path: - sys.path.insert(0, str(_REPO_ROOT)) -from hub_ohlcv_lib import format_price_by_tick, price_tick_from_market -from hub_position_metrics import ( - parse_position_entry_price, - parse_position_mark_price, - parse_position_unrealized_pnl, - resolve_position_display_upnl, -) - -import ccxt -from fastapi import FastAPI, Header, HTTPException, Request -from fastapi.responses import JSONResponse -from pydantic import BaseModel - -from exchange_orders import ( - attach_orders_to_positions, - cancel_order as hub_cancel_order, - cancel_orders_for_symbol, - list_open_orders, - replace_position_tpsl, - symbols_match, -) - -HOST = os.getenv("HOST", "127.0.0.1") -PORT = int(os.getenv("PORT", "15200")) -CONTROL_TOKEN = (os.getenv("CONTROL_TOKEN") or "").strip() - -_raw_ex = (os.getenv("EXCHANGE") or "binance").strip().lower() -if _raw_ex in ("binance", "bnb", "ba"): - EXCHANGE_KIND = "binance" -elif _raw_ex in ("okx", "okex"): - EXCHANGE_KIND = "okx" -elif _raw_ex in ("gate", "gateio"): - EXCHANGE_KIND = "gate" -else: - EXCHANGE_KIND = "binance" - -# —— Binance —— -_bin_pos = (os.getenv("BINANCE_POSITION_MODE") or "hedge").strip().lower() -BINANCE_POSITION_MODE = "hedge" if _bin_pos in ("hedge", "dual", "double", "hedged") else "oneway" -_bin_margin = (os.getenv("BINANCE_MARGIN_MODE") or "cross").strip().lower() -BINANCE_DEFAULT_MARGIN_MODE = "cross" if _bin_margin in ("cross", "cross_margin") else "isolated" - -# —— OKX —— -OKX_TD_MODE = (os.getenv("OKX_TD_MODE") or "cross").strip() -_okx_pos = (os.getenv("OKX_POS_MODE") or "hedge").strip().lower() -OKX_POS_MODE = "hedge" if _okx_pos in ("hedge", "long_short_mode", "dual") else "net" - -# —— Gate —— -_gate_td = (os.getenv("GATE_TD_MODE") or "cross").strip().lower() -GATE_DEFAULT_MARGIN_MODE = "cross" if _gate_td in ("cross", "cross_margin") else "isolated" -_gate_pos = (os.getenv("GATE_POS_MODE") or "hedge").strip().lower() -GATE_POS_MODE = "hedge" if _gate_pos in ("hedge", "dual", "double") else "single" - -app = FastAPI(title="sub-agent", docs_url=None, redoc_url=None) -_ccxt_ex: Any = None -_markets_loaded = False - - -def _socks_proxy_url(prefix: str) -> str: - return (os.getenv(f"{prefix}_SOCKS_PROXY") or "").strip() - - -def _http_https_proxy(prefix: str) -> dict[str, str] | None: - http = (os.getenv(f"{prefix}_HTTP_PROXY") or "").strip() - https = (os.getenv(f"{prefix}_HTTPS_PROXY") or "").strip() - socks = _socks_proxy_url(prefix) - if socks: - return {"http": socks, "https": socks} - if http or https: - return {"http": http, "https": https} - return None - - -def _attach_proxies(ex: Any, prefix: str) -> None: - p = _http_https_proxy(prefix) - if p: - ex.proxies = p - - -def _make_exchange() -> Any: - if EXCHANGE_KIND == "binance": - key = (os.getenv("BINANCE_API_KEY") or "").strip() - secret = (os.getenv("BINANCE_API_SECRET") or "").strip() - if not key or not secret: - raise RuntimeError("缺少 BINANCE_API_KEY / BINANCE_API_SECRET") - ex = ccxt.binance( - { - "apiKey": key, - "secret": secret, - "enableRateLimit": True, - "options": { - "defaultType": "swap", - # ccxt 默认 fetch_balance 走现货;与监控项目一致,固定为 U 本位合约钱包 - "fetchBalance": {"defaultType": "swap"}, - "defaultMarginMode": BINANCE_DEFAULT_MARGIN_MODE, - "adjustForTimeDifference": True, - }, - } - ) - _attach_proxies(ex, "BINANCE") - return ex - - if EXCHANGE_KIND == "okx": - key = (os.getenv("OKX_API_KEY") or "").strip() - secret = (os.getenv("OKX_API_SECRET") or "").strip() - password = (os.getenv("OKX_API_PASSPHRASE") or "").strip() - if not key or not secret or not password: - raise RuntimeError("缺少 OKX_API_KEY / OKX_API_SECRET / OKX_API_PASSPHRASE") - ex = ccxt.okx( - { - "apiKey": key, - "secret": secret, - "password": password, - "enableRateLimit": True, - "options": { - "defaultType": "swap", - "hedged": OKX_POS_MODE == "hedge", - }, - } - ) - _attach_proxies(ex, "OKX") - return ex - - # gate - key = (os.getenv("GATE_API_KEY") or "").strip() - secret = (os.getenv("GATE_API_SECRET") or "").strip() - if not key or not secret: - raise RuntimeError("缺少 GATE_API_KEY / GATE_API_SECRET") - ex = ccxt.gateio( - { - "apiKey": key, - "secret": secret, - "enableRateLimit": True, - "options": { - "defaultType": "swap", - "defaultMarginMode": GATE_DEFAULT_MARGIN_MODE, - }, - } - ) - _attach_proxies(ex, "GATE") - return ex - - -def get_exchange() -> Any: - global _ccxt_ex - if _ccxt_ex is None: - _ccxt_ex = _make_exchange() - return _ccxt_ex - - -def _ensure_markets() -> None: - global _markets_loaded - if not _markets_loaded: - get_exchange().load_markets() - _markets_loaded = True - - -def _check_token(x_control_token: str | None) -> None: - if not CONTROL_TOKEN: - return - if (x_control_token or "").strip() != CONTROL_TOKEN: - raise HTTPException(status_code=401, detail="invalid token") - - -def _position_mode_label() -> str: - if EXCHANGE_KIND == "binance": - return BINANCE_POSITION_MODE - if EXCHANGE_KIND == "okx": - return OKX_POS_MODE - return GATE_POS_MODE - - -def _close_param_candidates_binance(direction: str) -> list[dict[str, Any]]: - ps = "LONG" if direction == "long" else "SHORT" - hedge_ro = {"positionSide": ps, "reduceOnly": True} - hedge_plain = {"positionSide": ps} - oneway_ro = {"reduceOnly": True} - oneway_plain: dict[str, Any] = {} - if BINANCE_POSITION_MODE == "hedge": - return [hedge_ro, hedge_plain, oneway_ro, oneway_plain] - return [oneway_ro, oneway_plain, hedge_ro, hedge_plain] - - -def _close_param_candidates_okx(direction: str) -> list[dict[str, Any]]: - base: dict[str, Any] = {"tdMode": OKX_TD_MODE} - out: list[dict[str, Any]] = [] - if OKX_POS_MODE == "hedge": - ps = "long" if direction == "long" else "short" - out.extend( - [ - {**base, "posSide": ps, "reduceOnly": True}, - {**base, "posSide": ps}, - ] - ) - out.extend([{**base, "reduceOnly": True}, dict(base)]) - return out - - -def _close_param_candidates_gate(_direction: str) -> list[dict[str, Any]]: - return [{"reduceOnly": True}, {}] - - -def _close_param_candidates(direction: str) -> list[dict[str, Any]]: - if EXCHANGE_KIND == "binance": - return _close_param_candidates_binance(direction) - if EXCHANGE_KIND == "okx": - return _close_param_candidates_okx(direction) - return _close_param_candidates_gate(direction) - - -def _retryable_close_err(msg: str) -> bool: - s = (msg or "").lower() - if "-4061" in s: - return True - if "-1106" in s and "reduceonly" in s: - return True - if "reduceonly" in s or "reduce only" in s: - return True - if "position side" in s or "positionside" in s or "pos side" in s: - return True - if "dual side" in s or "position mode" in s: - return True - return False - - -def _position_contracts(p: dict[str, Any]) -> float: - raw = p.get("contracts") - if raw is not None: - try: - return float(raw) - except (TypeError, ValueError): - pass - info = p.get("info") or {} - for k in ("positionAmt", "positionamt", "pos", "size"): - if k in info: - try: - v = float(info[k]) - if v != 0: - return v - except (TypeError, ValueError): - pass - return 0.0 - - -def _position_side(p: dict[str, Any], contracts: float) -> str: - s = (p.get("side") or "").lower() - if s in ("long", "short"): - return s - if contracts > 0: - return "long" - if contracts < 0: - return "short" - return "long" - - -def _cancel_symbol_orders(ex: Any, sym: str) -> None: - try: - ex.cancel_all_orders(sym, params={}) - except Exception: - pass - if EXCHANGE_KIND != "binance": - return - try: - m = ex.market(sym) - cid = m.get("id") - if cid and hasattr(ex, "fapiPrivateDeleteAlgoOpenOrders"): - ex.fapiPrivateDeleteAlgoOpenOrders({"symbol": cid}) - except Exception: - pass - - -class EmergencyClosePositionBody(BaseModel): - symbol: str - side: str - - -class CancelOrderBody(BaseModel): - symbol: str - order_id: str - channel: str = "regular" - - -class CancelSymbolOrdersBody(BaseModel): - symbol: str - scope: str = "all" # all | conditional | limit - - -class PlaceTpslBody(BaseModel): - symbol: str - side: str # long | short - stop_loss: float - take_profit: float - contracts: float | None = None - - -def _close_position_market( - ex: Any, sym: str, side: str, contracts: float -) -> tuple[dict[str, Any] | None, str | None]: - """市价平掉指定合约、方向;返回 (closed_info, error_message)。""" - side_n = (side or "").strip().lower() - if side_n not in ("long", "short"): - return None, f"无效方向: {side}" - close_side = "sell" if side_n == "long" else "buy" - direction = side_n - try: - amt = float(ex.amount_to_precision(sym, abs(float(contracts)))) - except Exception: - amt = abs(float(contracts)) - if amt <= 0: - return None, f"{sym}: 可平张数为 0" - order_resp = None - last_err: Exception | None = None - for params in _close_param_candidates(direction): - try: - order_resp = ex.create_order(sym, "market", close_side, amt, None, params) - last_err = None - break - except Exception as e: - last_err = e - if _retryable_close_err(str(e)): - continue - return None, f"{sym}: {e}" - if order_resp is None: - return None, f"{sym}: {last_err or '下单失败'}" - _cancel_symbol_orders(ex, sym) - return ( - {"symbol": sym, "side": side_n, "amount": amt, "order_id": order_resp.get("id")}, - None, - ) - - -def _is_local(host: str | None) -> bool: - if not host: - return False - h = host.lower() - return h in ("127.0.0.1", "::1", "localhost") or h.startswith("::ffff:127.0.0.1") - - -def _finite_or_none(x: Any) -> float | None: - try: - f = float(x) - return f if math.isfinite(f) else None - except (TypeError, ValueError): - return None - - -def _position_price_fmt(ex: Any, symbol: str, price: float | None) -> tuple[float | None, str | None, float | None]: - """返回 (原价, 交易所精度字符串, price_tick)。""" - if price is None or price <= 0 or not symbol: - return None, None, None - tick: float | None = None - try: - ex.load_markets() - unified = ex.market(symbol)["symbol"] - tick = price_tick_from_market(ex, unified) - px_str = str(ex.price_to_precision(unified, price)) - return _finite_or_none(float(px_str)), px_str, tick - except Exception: - return price, format_price_by_tick(price, tick), tick - - -def _position_entry_price(p: dict[str, Any]) -> float | None: - """四所 ccxt 持仓统一解析开仓均价(Binance/OKX/Gate 字段名不一致)。""" - return parse_position_entry_price(p) - - -def _position_contract_size(ex: Any, symbol: str) -> float: - try: - market = ex.market((symbol or "").strip()) - cs = float(market.get("contractSize") or 1) - return cs if cs > 0 else 1.0 - except Exception: - return 1.0 - - -def _position_mark_price(p: dict[str, Any]) -> float | None: - """四所 ccxt 持仓统一解析标记价(与实例 parse_ccxt_position_metrics 一致)。""" - return parse_position_mark_price(p) - - -def _ticker_mark_price(ex: Any, symbol: str) -> float | None: - """持仓行无 mark 时,用 ticker 补标记价(last/mark)。""" - sym = (symbol or "").strip() - if not sym: - return None - try: - t = ex.fetch_ticker(sym) - except Exception: - return None - if not isinstance(t, dict): - return None - info = t.get("info") if isinstance(t.get("info"), dict) else {} - for key in ( - t.get("mark"), - t.get("last"), - t.get("close"), - info.get("markPrice"), - info.get("mark_price"), - info.get("markPx"), - ): - px = _finite_or_none(key) - if px is not None and px > 0: - return px - return None - - -def _extract_usdt_total(balance: dict[str, Any]) -> float | None: - """从 ccxt balance 结构中尽量取出 USDT 总额(与 crypto_monitor_binance 一致)。""" - usdt_info = balance.get("USDT") or {} - if not isinstance(usdt_info, dict): - usdt_info = {} - total_map = balance.get("total") or {} - if not isinstance(total_map, dict): - total_map = {} - free_map = balance.get("free") or {} - if not isinstance(free_map, dict): - free_map = {} - total = usdt_info.get("total") - if total is None: - total = usdt_info.get("equity") - if total is None: - total = total_map.get("USDT") - if total is None: - total = usdt_info.get("free") - if total is None: - total = free_map.get("USDT") - try: - return float(total) if total is not None else None - except (TypeError, ValueError): - return None - - -def _binance_futures_usdt_asset_row(balance: Any) -> dict[str, Any] | None: - """U 本位合约 fetch_balance(type=swap) 的 info.assets 中 USDT 一行(与币安合约后台口径一致)。""" - if not isinstance(balance, dict): - return None - info = balance.get("info") - if not isinstance(info, dict): - return None - assets = info.get("assets") - if not isinstance(assets, list): - return None - for a in assets: - if isinstance(a, dict) and str(a.get("asset") or "").upper() == "USDT": - return a - return None - - -def _binance_swap_usdt_total(ex: Any) -> float | None: - """仅 U 本位永续合约账户 USDT(显式 type=swap,不用现货余额)。""" - try: - bal = ex.fetch_balance({"type": "swap"}) - except Exception: - return None - row = _binance_futures_usdt_asset_row(bal) - if row: - for k in ("marginBalance", "walletBalance", "crossWalletBalance", "balance"): - x = row.get(k) - if x is not None and str(x).strip() != "": - try: - fv = float(x) - if fv >= 0: - return fv - except (TypeError, ValueError): - pass - v = _extract_usdt_total(bal) - return float(v) if v is not None else None - - -@app.middleware("http") -async def local_only(request: Request, call_next): - if request.client and not _is_local(request.client.host): - return JSONResponse({"detail": "forbidden"}, status_code=403) - return await call_next(request) - - -@app.get("/health") -def health(): - return {"ok": True, "exchange": EXCHANGE_KIND} - - -@app.get("/status") -def status(x_control_token: str | None = Header(default=None, alias="X-Control-Token")): - try: - return _status_inner(x_control_token) - except HTTPException: - raise - except Exception as e: - return JSONResponse( - { - "ok": False, - "error": f"status: {e}", - "exchange": EXCHANGE_KIND, - "balance_usdt": None, - "positions": [], - "total_unrealized_pnl": None, - }, - status_code=200, - ) - - -def _status_inner(x_control_token: str | None) -> Any: - _check_token(x_control_token) - try: - ex = get_exchange() - except RuntimeError as e: - return JSONResponse( - { - "ok": False, - "error": str(e), - "exchange": EXCHANGE_KIND, - "balance_usdt": None, - "positions": [], - "total_unrealized_pnl": None, - }, - status_code=200, - ) - try: - _ensure_markets() - except Exception as e: - return JSONResponse( - { - "ok": False, - "error": f"load_markets: {e}", - "exchange": EXCHANGE_KIND, - "balance_usdt": None, - "positions": [], - "total_unrealized_pnl": None, - }, - status_code=200, - ) - balance_usdt: float | None = None - try: - if EXCHANGE_KIND == "binance": - balance_usdt = _binance_swap_usdt_total(ex) - else: - bal = ex.fetch_balance() - u = bal.get("USDT") or {} - if isinstance(u, dict) and u.get("total") is not None: - balance_usdt = _finite_or_none(u["total"]) - except Exception: - pass - - positions_out: list[dict[str, Any]] = [] - total_upnl = 0.0 - try: - raw = ex.fetch_positions() or [] - except Exception as e: - return JSONResponse( - { - "ok": False, - "error": str(e), - "exchange": EXCHANGE_KIND, - "balance_usdt": balance_usdt, - "positions": [], - "total_unrealized_pnl": None, - }, - status_code=200, - ) - - for p in raw: - if not isinstance(p, dict): - continue - c = _position_contracts(p) - if abs(c) < 1e-12: - continue - sym = p.get("symbol") or "" - side = _position_side(p, c) - entry_f = _position_entry_price(p) - mark_f = _position_mark_price(p) - if mark_f is None and sym: - mark_f = _ticker_mark_price(ex, sym) - cs = _position_contract_size(ex, sym) if sym else 1.0 - exchange_upnl = parse_position_unrealized_pnl(p) - upnl_f = resolve_position_display_upnl( - side, - entry_f, - mark_f, - abs(c), - cs, - exchange_upnl, - ) - if upnl_f is None: - upnl_f = 0.0 - total_upnl += upnl_f - notional = p.get("notional") - try: - notional_f = float(notional) if notional is not None else None - except (TypeError, ValueError): - notional_f = None - _, entry_fmt, price_tick = _position_price_fmt(ex, sym, entry_f) - _, mark_fmt, mark_tick = _position_price_fmt(ex, sym, mark_f) - if price_tick is None and mark_tick is not None: - price_tick = mark_tick - positions_out.append( - { - "symbol": sym, - "side": side, - "contracts": abs(c), - "contracts_signed": c, - "notional_usdt": _finite_or_none(notional_f) if notional_f is not None else None, - "unrealized_pnl": _finite_or_none(upnl_f), - "entry_price": entry_f, - "entry_price_fmt": entry_fmt, - "mark_price": mark_f, - "mark_price_fmt": mark_fmt, - "contract_size": _finite_or_none(cs), - "price_tick": _finite_or_none(price_tick) if price_tick is not None else None, - } - ) - - orders_fetch_error: str | None = None - try: - attach_orders_to_positions( - positions_out, - list_open_orders(ex, EXCHANGE_KIND, None), - ) - except Exception as e: - orders_fetch_error = str(e) - for p in positions_out: - p.setdefault("conditional_orders", []) - p.setdefault("regular_orders", []) - - try: - pm = _position_mode_label() - except Exception: - pm = EXCHANGE_KIND - out = { - "ok": True, - "exchange": EXCHANGE_KIND, - "balance_usdt": balance_usdt, - "positions": positions_out, - "total_unrealized_pnl": _finite_or_none(total_upnl), - "position_mode": pm, - } - if orders_fetch_error: - out["orders_fetch_error"] = orders_fetch_error - return out - - -@app.get("/open-orders") -def open_orders( - symbol: str = "", - x_control_token: str | None = Header(default=None, alias="X-Control-Token"), -): - _check_token(x_control_token) - try: - ex = get_exchange() - _ensure_markets() - sym = (symbol or "").strip() or None - orders = list_open_orders(ex, EXCHANGE_KIND, sym) - return {"ok": True, "exchange": EXCHANGE_KIND, "symbol": sym, "orders": orders} - except Exception as e: - return JSONResponse( - {"ok": False, "error": str(e), "exchange": EXCHANGE_KIND, "orders": []}, - status_code=200, - ) - - -@app.post("/orders/cancel") -def cancel_one_order( - body: CancelOrderBody, - x_control_token: str | None = Header(default=None, alias="X-Control-Token"), -): - _check_token(x_control_token) - sym = (body.symbol or "").strip() - oid = (body.order_id or "").strip() - if not sym or not oid: - raise HTTPException(status_code=400, detail="symbol 与 order_id 必填") - try: - ex = get_exchange() - _ensure_markets() - hub_cancel_order(ex, EXCHANGE_KIND, sym, oid, body.channel or "regular") - return {"ok": True, "exchange": EXCHANGE_KIND, "cancelled": {"symbol": sym, "order_id": oid}} - except Exception as e: - return JSONResponse( - {"ok": False, "error": str(e), "exchange": EXCHANGE_KIND}, - status_code=200, - ) - - -@app.post("/orders/cancel-symbol") -def cancel_symbol_orders( - body: CancelSymbolOrdersBody, - x_control_token: str | None = Header(default=None, alias="X-Control-Token"), -): - _check_token(x_control_token) - sym = (body.symbol or "").strip() - if not sym: - raise HTTPException(status_code=400, detail="symbol 必填") - scope = (body.scope or "all").strip().lower() - if scope not in ("all", "conditional", "limit"): - raise HTTPException(status_code=400, detail="scope 须为 all / conditional / limit") - try: - ex = get_exchange() - _ensure_markets() - n = cancel_orders_for_symbol(ex, EXCHANGE_KIND, sym, scope=scope) - return {"ok": True, "exchange": EXCHANGE_KIND, "cancelled_count": n, "scope": scope} - except Exception as e: - return JSONResponse( - {"ok": False, "error": str(e), "exchange": EXCHANGE_KIND, "cancelled_count": 0}, - status_code=200, - ) - - -@app.post("/orders/place-tpsl") -def place_tpsl_orders( - body: PlaceTpslBody, - x_control_token: str | None = Header(default=None, alias="X-Control-Token"), -): - """先撤该合约全部条件单,再挂止盈+止损(与四实例策略逻辑一致)。""" - _check_token(x_control_token) - sym = (body.symbol or "").strip() - side = (body.side or "").strip().lower() - if not sym or side not in ("long", "short"): - raise HTTPException(status_code=400, detail="symbol 与 side(long/short) 必填") - try: - sl = float(body.stop_loss) - tp = float(body.take_profit) - except (TypeError, ValueError) as e: - raise HTTPException(status_code=400, detail="stop_loss / take_profit 须为数字") from e - try: - ex = get_exchange() - _ensure_markets() - amt = body.contracts - if amt is None or float(amt) <= 0: - raw = ex.fetch_positions() or [] - found = None - for p in raw: - psym = p.get("symbol") or "" - if not symbols_match(sym, psym): - continue - c = abs(float(p.get("contracts") or 0)) - if c <= 0: - continue - ps = (p.get("side") or "").lower() - if ps and ps != side: - continue - found = c - break - if found is None: - return JSONResponse( - {"ok": False, "error": f"未找到持仓 {sym} {side}", "exchange": EXCHANGE_KIND}, - status_code=200, - ) - amt = found - info = replace_position_tpsl(ex, EXCHANGE_KIND, sym, side, float(amt), sl, tp) - return {"ok": True, "exchange": EXCHANGE_KIND, "placed": info} - except HTTPException: - raise - except Exception as e: - return JSONResponse( - {"ok": False, "error": str(e), "exchange": EXCHANGE_KIND}, - status_code=200, - ) - - -@app.post("/emergency/close-all") -def emergency_close_all(x_control_token: str | None = Header(default=None, alias="X-Control-Token")): - _check_token(x_control_token) - try: - ex = get_exchange() - except RuntimeError as e: - raise HTTPException(status_code=503, detail=str(e)) from e - try: - _ensure_markets() - except Exception as e: - return JSONResponse( - {"ok": False, "error": f"load_markets: {e}", "closed": [], "errors": [str(e)], "exchange": EXCHANGE_KIND}, - status_code=200, - ) - errors: list[str] = [] - closed: list[dict[str, Any]] = [] - - try: - raw = ex.fetch_positions() or [] - except Exception as e: - raise HTTPException(status_code=502, detail=f"fetch_positions: {e}") from e - - for p in raw: - if not isinstance(p, dict): - continue - c = _position_contracts(p) - if abs(c) < 1e-12: - continue - sym = p.get("symbol") - if not sym: - continue - side = _position_side(p, c) - info, err = _close_position_market(ex, sym, side, abs(c)) - if err: - errors.append(err) - elif info: - closed.append(info) - time.sleep(0.05) - - return {"ok": len(errors) == 0, "closed": closed, "errors": errors, "exchange": EXCHANGE_KIND} - - -@app.post("/emergency/close-position") -def emergency_close_position( - body: EmergencyClosePositionBody, - x_control_token: str | None = Header(default=None, alias="X-Control-Token"), -): - _check_token(x_control_token) - sym = (body.symbol or "").strip() - want_side = (body.side or "").strip().lower() - if not sym: - raise HTTPException(status_code=400, detail="symbol 不能为空") - if want_side not in ("long", "short"): - raise HTTPException(status_code=400, detail="side 须为 long 或 short") - try: - ex = get_exchange() - except RuntimeError as e: - raise HTTPException(status_code=503, detail=str(e)) from e - try: - _ensure_markets() - except Exception as e: - return JSONResponse( - { - "ok": False, - "error": f"load_markets: {e}", - "closed": None, - "exchange": EXCHANGE_KIND, - }, - status_code=200, - ) - try: - raw = ex.fetch_positions() or [] - except Exception as e: - raise HTTPException(status_code=502, detail=f"fetch_positions: {e}") from e - - matched = None - for p in raw: - if not isinstance(p, dict): - continue - if (p.get("symbol") or "").strip() != sym: - continue - c = _position_contracts(p) - if abs(c) < 1e-12: - continue - side = _position_side(p, c) - if side != want_side: - continue - matched = (sym, side, abs(c)) - break - - if not matched: - return JSONResponse( - { - "ok": False, - "error": f"未找到持仓: {sym} {want_side}", - "closed": None, - "exchange": EXCHANGE_KIND, - }, - status_code=200, - ) - - sym, side, c = matched - info, err = _close_position_market(ex, sym, side, c) - if err: - return JSONResponse( - {"ok": False, "error": err, "closed": None, "exchange": EXCHANGE_KIND}, - status_code=200, - ) - return {"ok": True, "closed": info, "errors": [], "exchange": EXCHANGE_KIND} - - -def main(): - import uvicorn - - uvicorn.run(app, host=HOST, port=PORT, log_level="warning", access_log=False) - - -if __name__ == "__main__": - main() +""" +子账户极轻代理:GET /status、挂单/条件单查询与撤销、POST /emergency/close-all、POST /emergency/close-position,仅监听 127.0.0.1。 + +与仓库内四个策略/监控目录一一对应时,典型用法(各目录自己的 .env 里已有密钥;子代理用环境变量 PORT,勿与 Flask 的 APP_PORT 相同): + EXCHANGE=binance → crypto_monitor_binance(BINANCE_*) + EXCHANGE=okx → crypto_monitor_okx(OKX_*) + EXCHANGE=gate → crypto_monitor_gate / crypto_monitor_gate_bot(GATE_*) + +环境变量: + EXCHANGE binance(默认)| okx | gate + PORT 默认 15200(与 crypto_monitor_* 的 Flask APP_PORT 错开;中控默认聚合 15200–15203) + HOST 默认 127.0.0.1 + CONTROL_TOKEN 可选;请求头 X-Control-Token + +Binance:BINANCE_API_KEY / BINANCE_API_SECRET;余额为 **U 本位永续合约账户** USDT(与 `crypto_monitor_binance` 的合约口径一致,非现货钱包);BINANCE_POSITION_MODE;BINANCE_MARGIN_MODE +OKX:OKX_API_KEY / OKX_API_SECRET / OKX_API_PASSPHRASE;OKX_TD_MODE;OKX_POS_MODE +Gate:GATE_API_KEY / GATE_API_SECRET;GATE_TD_MODE;GATE_POS_MODE + +代理与主项目一致时可设:BINANCE_SOCKS_PROXY / OKX_SOCKS_PROXY / GATE_SOCKS_PROXY(或 HTTP(S)_PROXY)。 +""" +from __future__ import annotations + +import math +import os +import sys +import time +from pathlib import Path +from typing import Any + +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) +from lib.hub.hub_ohlcv_lib import format_price_by_tick, price_tick_from_market +from lib.hub.hub_position_metrics import ( + parse_position_entry_price, + parse_position_mark_price, + parse_position_unrealized_pnl, + resolve_position_display_upnl, +) + +import ccxt +from fastapi import FastAPI, Header, HTTPException, Request +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from exchange_orders import ( + attach_orders_to_positions, + cancel_order as hub_cancel_order, + cancel_orders_for_symbol, + list_open_orders, + replace_position_tpsl, + symbols_match, +) + +HOST = os.getenv("HOST", "127.0.0.1") +PORT = int(os.getenv("PORT", "15200")) +CONTROL_TOKEN = (os.getenv("CONTROL_TOKEN") or "").strip() + +_raw_ex = (os.getenv("EXCHANGE") or "binance").strip().lower() +if _raw_ex in ("binance", "bnb", "ba"): + EXCHANGE_KIND = "binance" +elif _raw_ex in ("okx", "okex"): + EXCHANGE_KIND = "okx" +elif _raw_ex in ("gate", "gateio"): + EXCHANGE_KIND = "gate" +else: + EXCHANGE_KIND = "binance" + +# —— Binance —— +_bin_pos = (os.getenv("BINANCE_POSITION_MODE") or "hedge").strip().lower() +BINANCE_POSITION_MODE = "hedge" if _bin_pos in ("hedge", "dual", "double", "hedged") else "oneway" +_bin_margin = (os.getenv("BINANCE_MARGIN_MODE") or "cross").strip().lower() +BINANCE_DEFAULT_MARGIN_MODE = "cross" if _bin_margin in ("cross", "cross_margin") else "isolated" + +# —— OKX —— +OKX_TD_MODE = (os.getenv("OKX_TD_MODE") or "cross").strip() +_okx_pos = (os.getenv("OKX_POS_MODE") or "hedge").strip().lower() +OKX_POS_MODE = "hedge" if _okx_pos in ("hedge", "long_short_mode", "dual") else "net" + +# —— Gate —— +_gate_td = (os.getenv("GATE_TD_MODE") or "cross").strip().lower() +GATE_DEFAULT_MARGIN_MODE = "cross" if _gate_td in ("cross", "cross_margin") else "isolated" +_gate_pos = (os.getenv("GATE_POS_MODE") or "hedge").strip().lower() +GATE_POS_MODE = "hedge" if _gate_pos in ("hedge", "dual", "double") else "single" + +app = FastAPI(title="sub-agent", docs_url=None, redoc_url=None) +_ccxt_ex: Any = None +_markets_loaded = False + + +def _socks_proxy_url(prefix: str) -> str: + return (os.getenv(f"{prefix}_SOCKS_PROXY") or "").strip() + + +def _http_https_proxy(prefix: str) -> dict[str, str] | None: + http = (os.getenv(f"{prefix}_HTTP_PROXY") or "").strip() + https = (os.getenv(f"{prefix}_HTTPS_PROXY") or "").strip() + socks = _socks_proxy_url(prefix) + if socks: + return {"http": socks, "https": socks} + if http or https: + return {"http": http, "https": https} + return None + + +def _attach_proxies(ex: Any, prefix: str) -> None: + p = _http_https_proxy(prefix) + if p: + ex.proxies = p + + +def _make_exchange() -> Any: + if EXCHANGE_KIND == "binance": + key = (os.getenv("BINANCE_API_KEY") or "").strip() + secret = (os.getenv("BINANCE_API_SECRET") or "").strip() + if not key or not secret: + raise RuntimeError("缺少 BINANCE_API_KEY / BINANCE_API_SECRET") + ex = ccxt.binance( + { + "apiKey": key, + "secret": secret, + "enableRateLimit": True, + "options": { + "defaultType": "swap", + # ccxt 默认 fetch_balance 走现货;与监控项目一致,固定为 U 本位合约钱包 + "fetchBalance": {"defaultType": "swap"}, + "defaultMarginMode": BINANCE_DEFAULT_MARGIN_MODE, + "adjustForTimeDifference": True, + }, + } + ) + _attach_proxies(ex, "BINANCE") + return ex + + if EXCHANGE_KIND == "okx": + key = (os.getenv("OKX_API_KEY") or "").strip() + secret = (os.getenv("OKX_API_SECRET") or "").strip() + password = (os.getenv("OKX_API_PASSPHRASE") or "").strip() + if not key or not secret or not password: + raise RuntimeError("缺少 OKX_API_KEY / OKX_API_SECRET / OKX_API_PASSPHRASE") + ex = ccxt.okx( + { + "apiKey": key, + "secret": secret, + "password": password, + "enableRateLimit": True, + "options": { + "defaultType": "swap", + "hedged": OKX_POS_MODE == "hedge", + }, + } + ) + _attach_proxies(ex, "OKX") + return ex + + # gate + key = (os.getenv("GATE_API_KEY") or "").strip() + secret = (os.getenv("GATE_API_SECRET") or "").strip() + if not key or not secret: + raise RuntimeError("缺少 GATE_API_KEY / GATE_API_SECRET") + ex = ccxt.gateio( + { + "apiKey": key, + "secret": secret, + "enableRateLimit": True, + "options": { + "defaultType": "swap", + "defaultMarginMode": GATE_DEFAULT_MARGIN_MODE, + }, + } + ) + _attach_proxies(ex, "GATE") + return ex + + +def get_exchange() -> Any: + global _ccxt_ex + if _ccxt_ex is None: + _ccxt_ex = _make_exchange() + return _ccxt_ex + + +def _ensure_markets() -> None: + global _markets_loaded + if not _markets_loaded: + get_exchange().load_markets() + _markets_loaded = True + + +def _check_token(x_control_token: str | None) -> None: + if not CONTROL_TOKEN: + return + if (x_control_token or "").strip() != CONTROL_TOKEN: + raise HTTPException(status_code=401, detail="invalid token") + + +def _position_mode_label() -> str: + if EXCHANGE_KIND == "binance": + return BINANCE_POSITION_MODE + if EXCHANGE_KIND == "okx": + return OKX_POS_MODE + return GATE_POS_MODE + + +def _close_param_candidates_binance(direction: str) -> list[dict[str, Any]]: + ps = "LONG" if direction == "long" else "SHORT" + hedge_ro = {"positionSide": ps, "reduceOnly": True} + hedge_plain = {"positionSide": ps} + oneway_ro = {"reduceOnly": True} + oneway_plain: dict[str, Any] = {} + if BINANCE_POSITION_MODE == "hedge": + return [hedge_ro, hedge_plain, oneway_ro, oneway_plain] + return [oneway_ro, oneway_plain, hedge_ro, hedge_plain] + + +def _close_param_candidates_okx(direction: str) -> list[dict[str, Any]]: + base: dict[str, Any] = {"tdMode": OKX_TD_MODE} + out: list[dict[str, Any]] = [] + if OKX_POS_MODE == "hedge": + ps = "long" if direction == "long" else "short" + out.extend( + [ + {**base, "posSide": ps, "reduceOnly": True}, + {**base, "posSide": ps}, + ] + ) + out.extend([{**base, "reduceOnly": True}, dict(base)]) + return out + + +def _close_param_candidates_gate(_direction: str) -> list[dict[str, Any]]: + return [{"reduceOnly": True}, {}] + + +def _close_param_candidates(direction: str) -> list[dict[str, Any]]: + if EXCHANGE_KIND == "binance": + return _close_param_candidates_binance(direction) + if EXCHANGE_KIND == "okx": + return _close_param_candidates_okx(direction) + return _close_param_candidates_gate(direction) + + +def _retryable_close_err(msg: str) -> bool: + s = (msg or "").lower() + if "-4061" in s: + return True + if "-1106" in s and "reduceonly" in s: + return True + if "reduceonly" in s or "reduce only" in s: + return True + if "position side" in s or "positionside" in s or "pos side" in s: + return True + if "dual side" in s or "position mode" in s: + return True + return False + + +def _position_contracts(p: dict[str, Any]) -> float: + raw = p.get("contracts") + if raw is not None: + try: + return float(raw) + except (TypeError, ValueError): + pass + info = p.get("info") or {} + for k in ("positionAmt", "positionamt", "pos", "size"): + if k in info: + try: + v = float(info[k]) + if v != 0: + return v + except (TypeError, ValueError): + pass + return 0.0 + + +def _position_side(p: dict[str, Any], contracts: float) -> str: + s = (p.get("side") or "").lower() + if s in ("long", "short"): + return s + if contracts > 0: + return "long" + if contracts < 0: + return "short" + return "long" + + +def _cancel_symbol_orders(ex: Any, sym: str) -> None: + try: + ex.cancel_all_orders(sym, params={}) + except Exception: + pass + if EXCHANGE_KIND != "binance": + return + try: + m = ex.market(sym) + cid = m.get("id") + if cid and hasattr(ex, "fapiPrivateDeleteAlgoOpenOrders"): + ex.fapiPrivateDeleteAlgoOpenOrders({"symbol": cid}) + except Exception: + pass + + +class EmergencyClosePositionBody(BaseModel): + symbol: str + side: str + + +class CancelOrderBody(BaseModel): + symbol: str + order_id: str + channel: str = "regular" + + +class CancelSymbolOrdersBody(BaseModel): + symbol: str + scope: str = "all" # all | conditional | limit + + +class PlaceTpslBody(BaseModel): + symbol: str + side: str # long | short + stop_loss: float + take_profit: float + contracts: float | None = None + + +def _close_position_market( + ex: Any, sym: str, side: str, contracts: float +) -> tuple[dict[str, Any] | None, str | None]: + """市价平掉指定合约、方向;返回 (closed_info, error_message)。""" + side_n = (side or "").strip().lower() + if side_n not in ("long", "short"): + return None, f"无效方向: {side}" + close_side = "sell" if side_n == "long" else "buy" + direction = side_n + try: + amt = float(ex.amount_to_precision(sym, abs(float(contracts)))) + except Exception: + amt = abs(float(contracts)) + if amt <= 0: + return None, f"{sym}: 可平张数为 0" + order_resp = None + last_err: Exception | None = None + for params in _close_param_candidates(direction): + try: + order_resp = ex.create_order(sym, "market", close_side, amt, None, params) + last_err = None + break + except Exception as e: + last_err = e + if _retryable_close_err(str(e)): + continue + return None, f"{sym}: {e}" + if order_resp is None: + return None, f"{sym}: {last_err or '下单失败'}" + _cancel_symbol_orders(ex, sym) + return ( + {"symbol": sym, "side": side_n, "amount": amt, "order_id": order_resp.get("id")}, + None, + ) + + +def _is_local(host: str | None) -> bool: + if not host: + return False + h = host.lower() + return h in ("127.0.0.1", "::1", "localhost") or h.startswith("::ffff:127.0.0.1") + + +def _finite_or_none(x: Any) -> float | None: + try: + f = float(x) + return f if math.isfinite(f) else None + except (TypeError, ValueError): + return None + + +def _position_price_fmt(ex: Any, symbol: str, price: float | None) -> tuple[float | None, str | None, float | None]: + """返回 (原价, 交易所精度字符串, price_tick)。""" + if price is None or price <= 0 or not symbol: + return None, None, None + tick: float | None = None + try: + ex.load_markets() + unified = ex.market(symbol)["symbol"] + tick = price_tick_from_market(ex, unified) + px_str = str(ex.price_to_precision(unified, price)) + return _finite_or_none(float(px_str)), px_str, tick + except Exception: + return price, format_price_by_tick(price, tick), tick + + +def _position_entry_price(p: dict[str, Any]) -> float | None: + """四所 ccxt 持仓统一解析开仓均价(Binance/OKX/Gate 字段名不一致)。""" + return parse_position_entry_price(p) + + +def _position_contract_size(ex: Any, symbol: str) -> float: + try: + market = ex.market((symbol or "").strip()) + cs = float(market.get("contractSize") or 1) + return cs if cs > 0 else 1.0 + except Exception: + return 1.0 + + +def _position_mark_price(p: dict[str, Any]) -> float | None: + """四所 ccxt 持仓统一解析标记价(与实例 parse_ccxt_position_metrics 一致)。""" + return parse_position_mark_price(p) + + +def _ticker_mark_price(ex: Any, symbol: str) -> float | None: + """持仓行无 mark 时,用 ticker 补标记价(last/mark)。""" + sym = (symbol or "").strip() + if not sym: + return None + try: + t = ex.fetch_ticker(sym) + except Exception: + return None + if not isinstance(t, dict): + return None + info = t.get("info") if isinstance(t.get("info"), dict) else {} + for key in ( + t.get("mark"), + t.get("last"), + t.get("close"), + info.get("markPrice"), + info.get("mark_price"), + info.get("markPx"), + ): + px = _finite_or_none(key) + if px is not None and px > 0: + return px + return None + + +def _extract_usdt_total(balance: dict[str, Any]) -> float | None: + """从 ccxt balance 结构中尽量取出 USDT 总额(与 crypto_monitor_binance 一致)。""" + usdt_info = balance.get("USDT") or {} + if not isinstance(usdt_info, dict): + usdt_info = {} + total_map = balance.get("total") or {} + if not isinstance(total_map, dict): + total_map = {} + free_map = balance.get("free") or {} + if not isinstance(free_map, dict): + free_map = {} + total = usdt_info.get("total") + if total is None: + total = usdt_info.get("equity") + if total is None: + total = total_map.get("USDT") + if total is None: + total = usdt_info.get("free") + if total is None: + total = free_map.get("USDT") + try: + return float(total) if total is not None else None + except (TypeError, ValueError): + return None + + +def _binance_futures_usdt_asset_row(balance: Any) -> dict[str, Any] | None: + """U 本位合约 fetch_balance(type=swap) 的 info.assets 中 USDT 一行(与币安合约后台口径一致)。""" + if not isinstance(balance, dict): + return None + info = balance.get("info") + if not isinstance(info, dict): + return None + assets = info.get("assets") + if not isinstance(assets, list): + return None + for a in assets: + if isinstance(a, dict) and str(a.get("asset") or "").upper() == "USDT": + return a + return None + + +def _binance_swap_usdt_total(ex: Any) -> float | None: + """仅 U 本位永续合约账户 USDT(显式 type=swap,不用现货余额)。""" + try: + bal = ex.fetch_balance({"type": "swap"}) + except Exception: + return None + row = _binance_futures_usdt_asset_row(bal) + if row: + for k in ("marginBalance", "walletBalance", "crossWalletBalance", "balance"): + x = row.get(k) + if x is not None and str(x).strip() != "": + try: + fv = float(x) + if fv >= 0: + return fv + except (TypeError, ValueError): + pass + v = _extract_usdt_total(bal) + return float(v) if v is not None else None + + +@app.middleware("http") +async def local_only(request: Request, call_next): + if request.client and not _is_local(request.client.host): + return JSONResponse({"detail": "forbidden"}, status_code=403) + return await call_next(request) + + +@app.get("/health") +def health(): + return {"ok": True, "exchange": EXCHANGE_KIND} + + +@app.get("/status") +def status(x_control_token: str | None = Header(default=None, alias="X-Control-Token")): + try: + return _status_inner(x_control_token) + except HTTPException: + raise + except Exception as e: + return JSONResponse( + { + "ok": False, + "error": f"status: {e}", + "exchange": EXCHANGE_KIND, + "balance_usdt": None, + "positions": [], + "total_unrealized_pnl": None, + }, + status_code=200, + ) + + +def _status_inner(x_control_token: str | None) -> Any: + _check_token(x_control_token) + try: + ex = get_exchange() + except RuntimeError as e: + return JSONResponse( + { + "ok": False, + "error": str(e), + "exchange": EXCHANGE_KIND, + "balance_usdt": None, + "positions": [], + "total_unrealized_pnl": None, + }, + status_code=200, + ) + try: + _ensure_markets() + except Exception as e: + return JSONResponse( + { + "ok": False, + "error": f"load_markets: {e}", + "exchange": EXCHANGE_KIND, + "balance_usdt": None, + "positions": [], + "total_unrealized_pnl": None, + }, + status_code=200, + ) + balance_usdt: float | None = None + try: + if EXCHANGE_KIND == "binance": + balance_usdt = _binance_swap_usdt_total(ex) + else: + bal = ex.fetch_balance() + u = bal.get("USDT") or {} + if isinstance(u, dict) and u.get("total") is not None: + balance_usdt = _finite_or_none(u["total"]) + except Exception: + pass + + positions_out: list[dict[str, Any]] = [] + total_upnl = 0.0 + try: + raw = ex.fetch_positions() or [] + except Exception as e: + return JSONResponse( + { + "ok": False, + "error": str(e), + "exchange": EXCHANGE_KIND, + "balance_usdt": balance_usdt, + "positions": [], + "total_unrealized_pnl": None, + }, + status_code=200, + ) + + for p in raw: + if not isinstance(p, dict): + continue + c = _position_contracts(p) + if abs(c) < 1e-12: + continue + sym = p.get("symbol") or "" + side = _position_side(p, c) + entry_f = _position_entry_price(p) + mark_f = _position_mark_price(p) + if mark_f is None and sym: + mark_f = _ticker_mark_price(ex, sym) + cs = _position_contract_size(ex, sym) if sym else 1.0 + exchange_upnl = parse_position_unrealized_pnl(p) + upnl_f = resolve_position_display_upnl( + side, + entry_f, + mark_f, + abs(c), + cs, + exchange_upnl, + ) + if upnl_f is None: + upnl_f = 0.0 + total_upnl += upnl_f + notional = p.get("notional") + try: + notional_f = float(notional) if notional is not None else None + except (TypeError, ValueError): + notional_f = None + _, entry_fmt, price_tick = _position_price_fmt(ex, sym, entry_f) + _, mark_fmt, mark_tick = _position_price_fmt(ex, sym, mark_f) + if price_tick is None and mark_tick is not None: + price_tick = mark_tick + positions_out.append( + { + "symbol": sym, + "side": side, + "contracts": abs(c), + "contracts_signed": c, + "notional_usdt": _finite_or_none(notional_f) if notional_f is not None else None, + "unrealized_pnl": _finite_or_none(upnl_f), + "entry_price": entry_f, + "entry_price_fmt": entry_fmt, + "mark_price": mark_f, + "mark_price_fmt": mark_fmt, + "contract_size": _finite_or_none(cs), + "price_tick": _finite_or_none(price_tick) if price_tick is not None else None, + } + ) + + orders_fetch_error: str | None = None + try: + attach_orders_to_positions( + positions_out, + list_open_orders(ex, EXCHANGE_KIND, None), + ) + except Exception as e: + orders_fetch_error = str(e) + for p in positions_out: + p.setdefault("conditional_orders", []) + p.setdefault("regular_orders", []) + + try: + pm = _position_mode_label() + except Exception: + pm = EXCHANGE_KIND + out = { + "ok": True, + "exchange": EXCHANGE_KIND, + "balance_usdt": balance_usdt, + "positions": positions_out, + "total_unrealized_pnl": _finite_or_none(total_upnl), + "position_mode": pm, + } + if orders_fetch_error: + out["orders_fetch_error"] = orders_fetch_error + return out + + +@app.get("/open-orders") +def open_orders( + symbol: str = "", + x_control_token: str | None = Header(default=None, alias="X-Control-Token"), +): + _check_token(x_control_token) + try: + ex = get_exchange() + _ensure_markets() + sym = (symbol or "").strip() or None + orders = list_open_orders(ex, EXCHANGE_KIND, sym) + return {"ok": True, "exchange": EXCHANGE_KIND, "symbol": sym, "orders": orders} + except Exception as e: + return JSONResponse( + {"ok": False, "error": str(e), "exchange": EXCHANGE_KIND, "orders": []}, + status_code=200, + ) + + +@app.post("/orders/cancel") +def cancel_one_order( + body: CancelOrderBody, + x_control_token: str | None = Header(default=None, alias="X-Control-Token"), +): + _check_token(x_control_token) + sym = (body.symbol or "").strip() + oid = (body.order_id or "").strip() + if not sym or not oid: + raise HTTPException(status_code=400, detail="symbol 与 order_id 必填") + try: + ex = get_exchange() + _ensure_markets() + hub_cancel_order(ex, EXCHANGE_KIND, sym, oid, body.channel or "regular") + return {"ok": True, "exchange": EXCHANGE_KIND, "cancelled": {"symbol": sym, "order_id": oid}} + except Exception as e: + return JSONResponse( + {"ok": False, "error": str(e), "exchange": EXCHANGE_KIND}, + status_code=200, + ) + + +@app.post("/orders/cancel-symbol") +def cancel_symbol_orders( + body: CancelSymbolOrdersBody, + x_control_token: str | None = Header(default=None, alias="X-Control-Token"), +): + _check_token(x_control_token) + sym = (body.symbol or "").strip() + if not sym: + raise HTTPException(status_code=400, detail="symbol 必填") + scope = (body.scope or "all").strip().lower() + if scope not in ("all", "conditional", "limit"): + raise HTTPException(status_code=400, detail="scope 须为 all / conditional / limit") + try: + ex = get_exchange() + _ensure_markets() + n = cancel_orders_for_symbol(ex, EXCHANGE_KIND, sym, scope=scope) + return {"ok": True, "exchange": EXCHANGE_KIND, "cancelled_count": n, "scope": scope} + except Exception as e: + return JSONResponse( + {"ok": False, "error": str(e), "exchange": EXCHANGE_KIND, "cancelled_count": 0}, + status_code=200, + ) + + +@app.post("/orders/place-tpsl") +def place_tpsl_orders( + body: PlaceTpslBody, + x_control_token: str | None = Header(default=None, alias="X-Control-Token"), +): + """先撤该合约全部条件单,再挂止盈+止损(与四实例策略逻辑一致)。""" + _check_token(x_control_token) + sym = (body.symbol or "").strip() + side = (body.side or "").strip().lower() + if not sym or side not in ("long", "short"): + raise HTTPException(status_code=400, detail="symbol 与 side(long/short) 必填") + try: + sl = float(body.stop_loss) + tp = float(body.take_profit) + except (TypeError, ValueError) as e: + raise HTTPException(status_code=400, detail="stop_loss / take_profit 须为数字") from e + try: + ex = get_exchange() + _ensure_markets() + amt = body.contracts + if amt is None or float(amt) <= 0: + raw = ex.fetch_positions() or [] + found = None + for p in raw: + psym = p.get("symbol") or "" + if not symbols_match(sym, psym): + continue + c = abs(float(p.get("contracts") or 0)) + if c <= 0: + continue + ps = (p.get("side") or "").lower() + if ps and ps != side: + continue + found = c + break + if found is None: + return JSONResponse( + {"ok": False, "error": f"未找到持仓 {sym} {side}", "exchange": EXCHANGE_KIND}, + status_code=200, + ) + amt = found + info = replace_position_tpsl(ex, EXCHANGE_KIND, sym, side, float(amt), sl, tp) + return {"ok": True, "exchange": EXCHANGE_KIND, "placed": info} + except HTTPException: + raise + except Exception as e: + return JSONResponse( + {"ok": False, "error": str(e), "exchange": EXCHANGE_KIND}, + status_code=200, + ) + + +@app.post("/emergency/close-all") +def emergency_close_all(x_control_token: str | None = Header(default=None, alias="X-Control-Token")): + _check_token(x_control_token) + try: + ex = get_exchange() + except RuntimeError as e: + raise HTTPException(status_code=503, detail=str(e)) from e + try: + _ensure_markets() + except Exception as e: + return JSONResponse( + {"ok": False, "error": f"load_markets: {e}", "closed": [], "errors": [str(e)], "exchange": EXCHANGE_KIND}, + status_code=200, + ) + errors: list[str] = [] + closed: list[dict[str, Any]] = [] + + try: + raw = ex.fetch_positions() or [] + except Exception as e: + raise HTTPException(status_code=502, detail=f"fetch_positions: {e}") from e + + for p in raw: + if not isinstance(p, dict): + continue + c = _position_contracts(p) + if abs(c) < 1e-12: + continue + sym = p.get("symbol") + if not sym: + continue + side = _position_side(p, c) + info, err = _close_position_market(ex, sym, side, abs(c)) + if err: + errors.append(err) + elif info: + closed.append(info) + time.sleep(0.05) + + return {"ok": len(errors) == 0, "closed": closed, "errors": errors, "exchange": EXCHANGE_KIND} + + +@app.post("/emergency/close-position") +def emergency_close_position( + body: EmergencyClosePositionBody, + x_control_token: str | None = Header(default=None, alias="X-Control-Token"), +): + _check_token(x_control_token) + sym = (body.symbol or "").strip() + want_side = (body.side or "").strip().lower() + if not sym: + raise HTTPException(status_code=400, detail="symbol 不能为空") + if want_side not in ("long", "short"): + raise HTTPException(status_code=400, detail="side 须为 long 或 short") + try: + ex = get_exchange() + except RuntimeError as e: + raise HTTPException(status_code=503, detail=str(e)) from e + try: + _ensure_markets() + except Exception as e: + return JSONResponse( + { + "ok": False, + "error": f"load_markets: {e}", + "closed": None, + "exchange": EXCHANGE_KIND, + }, + status_code=200, + ) + try: + raw = ex.fetch_positions() or [] + except Exception as e: + raise HTTPException(status_code=502, detail=f"fetch_positions: {e}") from e + + matched = None + for p in raw: + if not isinstance(p, dict): + continue + if (p.get("symbol") or "").strip() != sym: + continue + c = _position_contracts(p) + if abs(c) < 1e-12: + continue + side = _position_side(p, c) + if side != want_side: + continue + matched = (sym, side, abs(c)) + break + + if not matched: + return JSONResponse( + { + "ok": False, + "error": f"未找到持仓: {sym} {want_side}", + "closed": None, + "exchange": EXCHANGE_KIND, + }, + status_code=200, + ) + + sym, side, c = matched + info, err = _close_position_market(ex, sym, side, c) + if err: + return JSONResponse( + {"ok": False, "error": err, "closed": None, "exchange": EXCHANGE_KIND}, + status_code=200, + ) + return {"ok": True, "closed": info, "errors": [], "exchange": EXCHANGE_KIND} + + +def main(): + import uvicorn + + uvicorn.run(app, host=HOST, port=PORT, log_level="warning", access_log=False) + + +if __name__ == "__main__": + main() diff --git a/manual_trading_hub/exchange_orders.py b/manual_trading_hub/exchange_orders.py index a194a1c..268565b 100644 --- a/manual_trading_hub/exchange_orders.py +++ b/manual_trading_hub/exchange_orders.py @@ -1,862 +1,862 @@ -""" -中控子代理:拉取交易所挂单/条件单并规范化展示;撤销单笔或按合约批量撤销;挂止盈止损(先撤条件单再挂)。 -""" -from __future__ import annotations - -import os -import time -from typing import Any - -from okx_orders_lib import fetch_okx_all_open_orders - - -def _coerce_float(*values) -> float | None: - for v in values: - if v is None or v == "": - continue - try: - return float(v) - except (TypeError, ValueError): - continue - return None - - -def _symbol_base_coin(symbol: str) -> str: - """ZEC/USDT:USDT、ZEC-USDT-SWAP 等统一为标的币 ZEC。""" - s = (symbol or "").strip().upper() - if not s: - return "" - if "-SWAP" in s: - s = s.replace("-SWAP", "") - if "-" in s: - return s.split("-", 1)[0] - if "/" in s: - return s.split("/", 1)[0] - if ":" in s: - return s.split(":", 1)[0] - return s - - -def symbols_match(position_symbol: str, order_symbol: str) -> bool: - a = (position_symbol or "").strip().upper() - b = (order_symbol or "").strip().upper() - if not a or not b: - return False - if a == b: - return True - ba, bb = _symbol_base_coin(a), _symbol_base_coin(b) - if ba and bb and ba == bb: - return True - for suf in (":USDT", "/USDT:USDT", "/USDT"): - a2 = a.replace(suf, "") - b2 = b.replace(suf, "") - if f"{a2}/USDT" == b or f"{a2}/USDT:USDT" == b: - return True - if f"{b2}/USDT" == a or f"{b2}/USDT:USDT" == a: - return True - if a2 == b2: - return True - return False - - -def _order_type_str(order: dict) -> str: - info = order.get("info") or {} - if isinstance(info, dict): - for key in ("orderType", "type", "origType", "algoType", "ordType"): - val = info.get(key) - if val: - return str(val).upper() - return str(order.get("type") or "").upper() - - -def _is_conditional_type(typ: str) -> bool: - t = (typ or "").upper() - if not t: - return False - keys = ("STOP", "TAKE_PROFIT", "TRAIL", "TRIGGER", "CONDITIONAL", "OCO") - return any(k in t for k in keys) - - -def _order_label(typ: str, side: str, reduce_only: bool | None) -> str: - t = (typ or "").upper() - side_l = (side or "").lower() - parts = [] - if "TAKE_PROFIT" in t: - parts.append("止盈") - elif "STOP" in t: - parts.append("止损") - elif "LIMIT" in t: - parts.append("限价") - elif "MARKET" in t: - parts.append("市价") - else: - parts.append(typ or "委托") - if side_l == "buy": - parts.append("买入") - elif side_l == "sell": - parts.append("卖出") - if reduce_only: - parts.append("·只减仓") - return " ".join(parts) - - -def _normalize_raw_order(order: dict, *, channel: str) -> dict[str, Any] | None: - if not isinstance(order, dict): - return None - info = order.get("info") or {} - if not isinstance(info, dict): - info = {} - oid = order.get("id") or info.get("algoId") or info.get("orderId") or info.get("ordId") - if oid is None: - return None - sym = str(order.get("symbol") or info.get("symbol") or info.get("instId") or "") - typ = _order_type_str(order) - side = str(order.get("side") or info.get("side") or "").lower() - reduce_only = order.get("reduceOnly") - if reduce_only is None: - reduce_only = info.get("reduceOnly") - try: - reduce_only = bool(reduce_only) if reduce_only is not None else None - except (TypeError, ValueError): - reduce_only = None - sl_trig = _coerce_float(info.get("slTriggerPx"), order.get("stopLossPrice")) - tp_trig = _coerce_float(info.get("tpTriggerPx"), order.get("takeProfitPrice")) - trig = _coerce_float( - order.get("stopPrice"), - order.get("triggerPrice"), - info.get("triggerPrice"), - info.get("stopPrice"), - info.get("triggerPx"), - sl_trig, - tp_trig, - ) - price = _coerce_float(order.get("price"), info.get("price"), info.get("ordPx")) - amt = _coerce_float(order.get("amount"), order.get("remaining"), info.get("quantity"), info.get("origQty"), info.get("sz")) - category = "conditional" if _is_conditional_type(typ) or channel == "algo" else "limit" - label = _order_label(typ, side, reduce_only) - if sl_trig is not None and tp_trig is not None: - label = f"止盈止损 SL={sl_trig:g} TP={tp_trig:g}" - elif sl_trig is not None: - label = f"止损 {sl_trig:g}" - elif tp_trig is not None: - label = f"止盈 {tp_trig:g}" - return { - "id": str(oid), - "symbol": sym, - "channel": channel, - "category": category, - "label": label, - "type": typ, - "side": side, - "amount": amt, - "trigger_price": trig, - "price": price, - "reduce_only": reduce_only, - "status": str(order.get("status") or info.get("status") or "open"), - } - - -def _okx_normalize_orders(raw: dict, channel: str) -> list[dict[str, Any]]: - """OKX 算法单常一笔同时含 SL+TP,拆成两条供中控「交易所止盈止损」展示。""" - n = _normalize_raw_order(dict(raw), channel=channel) - if not n: - return [] - info = raw.get("info") or {} - if not isinstance(info, dict): - info = {} - sl_trig = _coerce_float(info.get("slTriggerPx"), raw.get("stopLossPrice")) - tp_trig = _coerce_float(info.get("tpTriggerPx"), raw.get("takeProfitPrice")) - if sl_trig is None or tp_trig is None or sl_trig == tp_trig: - return [n] - base_id = n["id"] - rows: list[dict[str, Any]] = [] - for role, px, lbl in ( - ("sl", sl_trig, f"止损 {sl_trig:g}"), - ("tp", tp_trig, f"止盈 {tp_trig:g}"), - ): - row = dict(n) - row["id"] = f"{base_id}:{role}" - row["algo_id"] = base_id - row["label"] = lbl - row["trigger_price"] = px - row["category"] = "conditional" - row["channel"] = channel - rows.append(row) - return rows - - -def _okx_algo_order_id(order_id: str) -> str: - oid = str(order_id or "") - if ":" in oid: - return oid.split(":", 1)[0] - return oid - - -def _binance_list(ex: Any, symbol: str | None) -> list[dict]: - ex.load_markets() - out: list[dict] = [] - symbols: list[str] = [] - if symbol: - try: - symbols = [ex.market(symbol)["symbol"]] - except Exception: - symbols = [symbol] - else: - symbols = [] - try: - for p in ex.fetch_positions() or []: - sym = p.get("symbol") - if sym: - symbols.append(sym) - except Exception: - pass - if symbol and not symbols: - symbols = [symbol] - - def collect(ex_sym: str) -> None: - market = ex.market(ex_sym) - contract_id = market.get("id") - try: - for o in ex.fetch_open_orders(ex_sym) or []: - item = dict(o) - item["_channel"] = "regular" - n = _normalize_raw_order(item, channel="regular") - if n: - out.append(n) - except Exception: - pass - try: - if contract_id and hasattr(ex, "fapiPrivateGetOpenAlgoOrders"): - raw = ex.fapiPrivateGetOpenAlgoOrders({"symbol": contract_id}) - items = raw if isinstance(raw, list) else (raw.get("orders") or raw.get("data") or []) - for info in items or []: - if not isinstance(info, dict): - continue - wrapped = { - "id": info.get("algoId") or info.get("orderId"), - "symbol": ex_sym, - "info": info, - "type": info.get("orderType") or info.get("type"), - "side": (info.get("side") or "").lower(), - "amount": info.get("quantity") or info.get("origQty"), - "stopPrice": info.get("triggerPrice") or info.get("stopPrice"), - "reduceOnly": info.get("reduceOnly"), - } - n = _normalize_raw_order(wrapped, channel="algo") - if n: - out.append(n) - except Exception: - pass - - if symbols: - seen = set() - for s in symbols: - if s in seen: - continue - seen.add(s) - collect(s) - return out - - -def _okx_list(ex: Any, symbol: str | None) -> list[dict]: - ex.load_markets() - out: list[dict] = [] - symbols: list[str] = [] - if symbol: - try: - symbols = [ex.market(symbol)["symbol"]] - except Exception: - symbols = [symbol] - else: - try: - for p in ex.fetch_positions() or []: - sym = p.get("symbol") - if sym: - symbols.append(sym) - except Exception: - pass - if symbol and not symbols: - symbols = [symbol] - seen: set[tuple[str, str]] = set() - for sym in symbols: - try: - for o in fetch_okx_all_open_orders(ex, sym): - ch = "algo" if _is_conditional_type(_order_type_str(o)) else "regular" - for n in _okx_normalize_orders(dict(o), channel=ch): - key = (n["id"], n.get("channel") or ch) - if key in seen: - continue - seen.add(key) - out.append(n) - except Exception: - pass - return out - - -def _gate_extract_trigger_rule(info: dict) -> int | None: - if not isinstance(info, dict): - return None - trig = info.get("trigger") - if isinstance(trig, dict) and trig.get("rule") is not None: - try: - return int(trig["rule"]) - except (TypeError, ValueError): - pass - try: - return int(info.get("rule")) - except (TypeError, ValueError): - return None - - -def _gate_tpsl_role_from_rule(rule: int | None, direction: str) -> str | None: - if rule is None: - return None - d = (direction or "long").strip().lower() - if d == "long": - return "sl" if rule == 2 else ("tp" if rule == 1 else None) - return "sl" if rule == 1 else ("tp" if rule == 2 else None) - - -def _gate_trigger_params(ex: Any) -> dict: - p = {"type": "swap", "trigger": True} - try: - ex.load_unified_status() - if ex.options.get("unifiedAccount"): - p["unifiedAccount"] = True - except Exception: - pass - return p - - -def _gate_list(ex: Any, symbol: str | None) -> list[dict]: - ex.load_markets() - out: list[dict] = [] - symbols: list[str] = [] - if symbol: - try: - symbols = [ex.market(symbol)["symbol"]] - except Exception: - symbols = [symbol] - else: - try: - for p in ex.fetch_positions() or []: - sym = p.get("symbol") - if sym: - symbols.append(sym) - except Exception: - pass - if symbol and not symbols: - symbols = [symbol] - trig_params = _gate_trigger_params(ex) - seen = set() - for sym in symbols: - if sym in seen: - continue - seen.add(sym) - try: - for o in ex.fetch_open_orders(sym) or []: - n = _normalize_raw_order(dict(o), channel="regular") - if n: - out.append(n) - except Exception: - pass - try: - for o in ex.fetch_open_orders(sym, params=trig_params) or []: - item = dict(o) - item["type"] = item.get("type") or "trigger" - n = _normalize_raw_order(item, channel="algo") - if n: - info = o.get("info") if isinstance(o.get("info"), dict) else {} - rule = _gate_extract_trigger_rule(info) - if rule is not None: - n["gate_trigger_rule"] = rule - out.append(n) - except Exception: - pass - return out - - -def list_open_orders(ex: Any, exchange_kind: str, symbol: str | None = None) -> list[dict]: - kind = (exchange_kind or "binance").lower() - if kind == "binance": - orders = _binance_list(ex, symbol) - elif kind == "okx": - orders = _okx_list(ex, symbol) - else: - orders = _gate_list(ex, symbol) - if symbol: - orders = [o for o in orders if symbols_match(symbol, o.get("symbol") or "")] - # 去重 id+channel - seen: set[tuple[str, str]] = set() - uniq: list[dict] = [] - for o in orders: - key = (o["id"], o["channel"]) - if key in seen: - continue - seen.add(key) - uniq.append(o) - return uniq - - -def _enrich_gate_conditional_labels(cond: list[dict], side: str) -> None: - """Gate 仓位类触发单在 ccxt 中常显示为「市价·只减仓」,按 trigger.rule 标为止盈/止损。""" - direction = (side or "long").strip().lower() - for o in cond: - if not isinstance(o, dict): - continue - if (o.get("label") or "").startswith(("止盈", "止损")): - continue - role = _gate_tpsl_role_from_rule(o.get("gate_trigger_rule"), direction) - trig = o.get("trigger_price") - if not role or trig is None: - continue - try: - trig_f = float(trig) - except (TypeError, ValueError): - continue - prefix = "止损" if role == "sl" else "止盈" - o["label"] = f"{prefix} {trig_f:g}" - - -def attach_orders_to_positions(positions: list[dict], orders: list[dict]) -> None: - for p in positions: - sym = p.get("symbol") or "" - matched = [o for o in orders if symbols_match(sym, o.get("symbol") or "")] - cond = [o for o in matched if o.get("category") == "conditional"] - _enrich_gate_conditional_labels(cond, p.get("side") or "long") - p["conditional_orders"] = cond - p["regular_orders"] = [o for o in matched if o.get("category") != "conditional"] - - -def cancel_order( - ex: Any, - exchange_kind: str, - symbol: str, - order_id: str, - channel: str = "regular", -) -> None: - kind = (exchange_kind or "binance").lower() - ex.load_markets() - market = ex.market(symbol) - unified = market["symbol"] - ch = (channel or "regular").lower() - if kind == "binance" and ch == "algo": - contract_id = market.get("id") - if contract_id and hasattr(ex, "fapiPrivateDeleteAlgoOrder"): - ex.fapiPrivateDeleteAlgoOrder({"symbol": contract_id, "algoId": str(order_id)}) - return - params = None - if kind == "gate" and ch == "algo": - params = _gate_trigger_params(ex) - elif kind == "okx" and ch == "algo": - params = {"stop": True} - oid = _okx_algo_order_id(order_id) if kind == "okx" else str(order_id) - ex.cancel_order(oid, unified, params) - - -def cancel_orders_for_symbol( - ex: Any, - exchange_kind: str, - symbol: str, - *, - scope: str = "all", -) -> int: - """scope: all | conditional | limit""" - orders = list_open_orders(ex, exchange_kind, symbol) - if scope == "conditional": - orders = [o for o in orders if o.get("category") == "conditional"] - elif scope == "limit": - orders = [o for o in orders if o.get("category") != "conditional"] - n = 0 - for o in orders: - try: - cancel_order(ex, exchange_kind, symbol, o["id"], o.get("channel") or "regular") - n += 1 - except Exception: - pass - return n - - -def _binance_cancel_algo_open(ex: Any, symbol: str) -> None: - try: - market = ex.market(symbol) - cid = market.get("id") - if cid and hasattr(ex, "fapiPrivateDeleteAlgoOpenOrders"): - ex.fapiPrivateDeleteAlgoOpenOrders({"symbol": cid}) - except Exception: - pass - - -def _binance_trigger_params() -> dict[str, Any]: - wt = (os.getenv("BINANCE_TRIGGER_WORKING_TYPE") or "CONTRACT_PRICE").strip().upper() - if wt not in ("CONTRACT_PRICE", "MARK_PRICE"): - wt = "CONTRACT_PRICE" - return {"workingType": wt} - - -def _binance_place_tp_sl( - ex: Any, - symbol: str, - direction: str, - amount: float, - stop_loss: float, - take_profit: float, - *, - position_mode: str = "hedge", -) -> None: - ex.load_markets() - market = ex.market(symbol) - if not market.get("swap"): - raise RuntimeError("仅支持永续合约") - close_side = "sell" if direction == "long" else "buy" - amt = float(ex.amount_to_precision(symbol, float(amount))) - if amt <= 0: - raise RuntimeError("止盈止损:可平数量经精度舍入后为 0") - sl_px = ex.price_to_precision(symbol, float(stop_loss)) - tp_px = ex.price_to_precision(symbol, float(take_profit)) - common = dict(_binance_trigger_params()) - if (position_mode or "hedge").lower() in ("hedge", "dual", "double", "hedged"): - common["positionSide"] = "LONG" if direction == "long" else "SHORT" - last_err: Exception | None = None - for attempt in range(6): - try: - ex.create_order( - symbol, "STOP_MARKET", close_side, amt, None, dict(common, stopPrice=sl_px) - ) - time.sleep(0.05) - ex.create_order( - symbol, - "TAKE_PROFIT_MARKET", - close_side, - amt, - None, - dict(common, stopPrice=tp_px), - ) - return - except Exception as e: - last_err = e - cancel_orders_for_symbol(ex, "binance", symbol, scope="conditional") - _binance_cancel_algo_open(ex, symbol) - time.sleep(0.2 * (attempt + 1)) - raise RuntimeError(f"Binance 未接受止盈/止损:{last_err}") - - -def _okx_order_params( - direction: str, - *, - reduce_only: bool, - pos_mode: str, - td_mode: str, - for_algo_tpsl: bool = False, -) -> dict: - params: dict[str, Any] = {"tdMode": td_mode or "cross"} - if (pos_mode or "hedge").lower() in ("hedge", "long_short_mode", "dual"): - ps = "long" if direction == "long" else "short" - params["posSide"] = ps - params["positionSide"] = ps - # OKX 条件/OCO 算法单勿带 reduceOnly,否则可能被当市价减仓立即成交 - if reduce_only and not for_algo_tpsl: - params["reduceOnly"] = True - return params - - -def _okx_place_tp_sl( - ex: Any, - symbol: str, - direction: str, - amount: float, - stop_loss: float, - take_profit: float, - *, - pos_mode: str = "hedge", - td_mode: str = "cross", -) -> None: - """OKX 永续:一笔 OCO 算法单挂止盈+止损(勿 reduceOnly + 分两笔 market)。""" - ex.load_markets() - close_side = "sell" if direction == "long" else "buy" - amt = float(ex.amount_to_precision(symbol, float(amount))) - if amt <= 0: - raise RuntimeError("止盈止损:可平数量经精度舍入后为 0") - base = _okx_order_params( - direction, - reduce_only=False, - pos_mode=pos_mode, - td_mode=td_mode, - for_algo_tpsl=True, - ) - sl_px = ex.price_to_precision(symbol, float(stop_loss)) - tp_px = ex.price_to_precision(symbol, float(take_profit)) - order_params = { - **base, - "stopLossPrice": float(sl_px), - "takeProfitPrice": float(tp_px), - "tpOrdPx": "-1", - "slOrdPx": "-1", - } - last_err: Exception | None = None - for attempt in range(6): - try: - ex.create_order(symbol, "oco", close_side, amt, None, order_params) - return - except Exception as e: - last_err = e - cancel_orders_for_symbol(ex, "okx", symbol, scope="conditional") - time.sleep(0.2 * (attempt + 1)) - raise RuntimeError(f"OKX 未接受止盈/止损条件单:{last_err}") - - -def _gate_tpsl_env() -> tuple[bool, int, int, str]: - use_pos = (os.getenv("GATE_TPSL_USE_POSITION_ORDER") or "true").lower() in ("1", "true", "yes") - exp = int(os.getenv("GATE_TPSL_TRIGGER_EXPIRATION", str(7 * 86400))) - pt = int(os.getenv("GATE_TPSL_PRICE_TYPE", "0")) - if pt < 0 or pt > 2: - pt = 0 - pos_mode = (os.getenv("GATE_POS_MODE") or "hedge").strip().lower() - return use_pos, exp, pt, pos_mode - - -def _gate_place_tp_sl_position( - ex: Any, - symbol: str, - direction: str, - stop_loss: float, - take_profit: float, - *, - pos_mode: str, - price_type: int, - expiration: int, -) -> None: - ex.load_markets() - market = ex.market(symbol) - if not market.get("swap"): - raise RuntimeError("仅支持永续合约") - settle = market["settleId"] - contract = market["id"] - order_type = "close-long-position" if direction == "long" else "close-short-position" - close_side = "sell" if direction == "long" else "buy" - sl_rule, tp_rule = (2, 1) if close_side == "sell" else (1, 2) - initial: dict[str, Any] = { - "contract": contract, - "size": 0, - "price": "0", - "close": True, - "reduce_only": True, - "tif": "ioc", - "text": "api", - } - if pos_mode in ("hedge", "dual", "double"): - initial["auto_size"] = "close_long" if direction == "long" else "close_short" - # Gate API 1018:auto_size=close_long|close_short 时 initial.close 须为 false - initial["close"] = False - sl_s = ex.price_to_precision(symbol, float(stop_loss)) - tp_s = ex.price_to_precision(symbol, float(take_profit)) - - def _payload(trigger_price: str, rule: int) -> dict: - trig: dict[str, Any] = { - "strategy_type": 0, - "price_type": price_type, - "price": trigger_price, - "rule": rule, - } - if expiration > 0: - trig["expiration"] = expiration - return { - "settle": settle, - "initial": dict(initial), - "trigger": trig, - "order_type": order_type, - } - - last_err: Exception | None = None - for attempt in range(6): - try: - ex.privateFuturesPostSettlePriceOrders(_payload(sl_s, sl_rule)) - try: - ex.privateFuturesPostSettlePriceOrders(_payload(tp_s, tp_rule)) - except Exception: - cancel_orders_for_symbol(ex, "gate", symbol, scope="conditional") - raise - return - except Exception as e: - last_err = e - time.sleep(0.2 * (attempt + 1)) - raise RuntimeError(f"Gate 仓位类止盈/止损未接受:{last_err}") - - -def _gate_place_tp_sl_legacy( - ex: Any, - symbol: str, - direction: str, - amount: float, - stop_loss: float, - take_profit: float, -) -> None: - ex.load_markets() - close_side = "sell" if direction == "long" else "buy" - base = {"reduceOnly": True} - last_err: Exception | None = None - for attempt in range(6): - try: - ex.create_order( - symbol, - "market", - close_side, - amount, - None, - dict(base, stopLossPrice=float(stop_loss)), - ) - ex.create_order( - symbol, - "market", - close_side, - amount, - None, - dict(base, takeProfitPrice=float(take_profit)), - ) - return - except Exception as e: - last_err = e - time.sleep(0.2 * (attempt + 1)) - raise RuntimeError(f"Gate 条件止盈/止损未接受:{last_err}") - - -def _gate_td_mode_cross() -> bool: - td = (os.getenv("GATE_TD_MODE") or "cross").strip().lower() - return td in ("cross", "cross_margin") - - -def _gate_last_price(ex: Any, symbol: str) -> float | None: - ex.load_markets() - unified = ex.market(symbol)["symbol"] - try: - t = ex.fetch_ticker(unified) - except Exception: - return None - if not isinstance(t, dict): - return None - info = t.get("info") if isinstance(t.get("info"), dict) else {} - for key in ("last", "mark", "close", "index_price"): - v = t.get(key) if key in t else info.get(key) - try: - f = float(v) - if f > 0: - return f - except (TypeError, ValueError): - continue - return None - - -def _gate_clamp_tpsl_prices( - ex: Any, - symbol: str, - direction: str, - stop_loss: float, - take_profit: float, -) -> tuple[float, float]: - """ - Gate price_orders:空仓止损/多仓止盈 trigger>last;空仓止盈/多仓止损 trigger= last: - tp = float(ex.price_to_precision(unified, last * (1 - gap))) - else: - if sl >= last: - sl = float(ex.price_to_precision(unified, last * (1 - gap))) - if tp <= last: - tp = float(ex.price_to_precision(unified, last * (1 + gap))) - return sl, tp - - -def _gate_place_tp_sl( - ex: Any, - symbol: str, - direction: str, - amount: float, - stop_loss: float, - take_profit: float, -) -> None: - use_pos, exp, pt, pos_mode = _gate_tpsl_env() - pos_err: Exception | None = None - if use_pos: - try: - _gate_place_tp_sl_position( - ex, symbol, direction, stop_loss, take_profit, - pos_mode=pos_mode, price_type=pt, expiration=exp, - ) - return - except Exception as e: - pos_err = e - if _gate_td_mode_cross(): - raise RuntimeError( - f"Gate 仓位类止盈/止损未接受(全仓不支持 ccxt 条件单回退):{pos_err}" - ) from e - try: - _gate_place_tp_sl_legacy(ex, symbol, direction, amount, stop_loss, take_profit) - except Exception as legacy_err: - if pos_err is not None: - raise RuntimeError( - f"Gate 仓位类止盈/止损未接受:{pos_err};条件单回退亦失败:{legacy_err}" - ) from legacy_err - raise - - -def replace_position_tpsl( - ex: Any, - exchange_kind: str, - symbol: str, - direction: str, - amount: float, - stop_loss: float, - take_profit: float, -) -> dict[str, Any]: - """ - 先撤销该合约全部条件单,再挂止盈+止损。与四实例策略页逻辑对齐(读各目录 .env 中 GATE_/BINANCE_/OKX_ 参数)。 - """ - kind = (exchange_kind or "binance").lower() - direction = (direction or "long").strip().lower() - if direction not in ("long", "short"): - raise ValueError("direction 须为 long 或 short") - sl = float(stop_loss) - tp = float(take_profit) - if sl <= 0 or tp <= 0: - raise ValueError("止损、止盈价格须大于 0") - ex.load_markets() - cancelled = cancel_orders_for_symbol(ex, kind, symbol, scope="conditional") - if kind == "binance": - _binance_cancel_algo_open(ex, symbol) - time.sleep(0.08) - amt = float(amount) - if amt <= 0: - raise ValueError("持仓数量无效") - if kind == "binance": - pm = (os.getenv("BINANCE_POSITION_MODE") or "hedge").strip().lower() - _binance_place_tp_sl(ex, symbol, direction, amt, sl, tp, position_mode=pm) - elif kind == "okx": - pm = (os.getenv("OKX_POS_MODE") or "hedge").strip().lower() - td = (os.getenv("OKX_TD_MODE") or "cross").strip() - _okx_place_tp_sl(ex, symbol, direction, amt, sl, tp, pos_mode=pm, td_mode=td) - else: - sl, tp = _gate_clamp_tpsl_prices(ex, symbol, direction, sl, tp) - _gate_place_tp_sl(ex, symbol, direction, amt, sl, tp) - return { - "symbol": symbol, - "direction": direction, - "amount": amt, - "stop_loss": sl, - "take_profit": tp, - "cancelled_conditional": cancelled, - } +""" +中控子代理:拉取交易所挂单/条件单并规范化展示;撤销单笔或按合约批量撤销;挂止盈止损(先撤条件单再挂)。 +""" +from __future__ import annotations + +import os +import time +from typing import Any + +from lib.exchange.okx_orders_lib import fetch_okx_all_open_orders + + +def _coerce_float(*values) -> float | None: + for v in values: + if v is None or v == "": + continue + try: + return float(v) + except (TypeError, ValueError): + continue + return None + + +def _symbol_base_coin(symbol: str) -> str: + """ZEC/USDT:USDT、ZEC-USDT-SWAP 等统一为标的币 ZEC。""" + s = (symbol or "").strip().upper() + if not s: + return "" + if "-SWAP" in s: + s = s.replace("-SWAP", "") + if "-" in s: + return s.split("-", 1)[0] + if "/" in s: + return s.split("/", 1)[0] + if ":" in s: + return s.split(":", 1)[0] + return s + + +def symbols_match(position_symbol: str, order_symbol: str) -> bool: + a = (position_symbol or "").strip().upper() + b = (order_symbol or "").strip().upper() + if not a or not b: + return False + if a == b: + return True + ba, bb = _symbol_base_coin(a), _symbol_base_coin(b) + if ba and bb and ba == bb: + return True + for suf in (":USDT", "/USDT:USDT", "/USDT"): + a2 = a.replace(suf, "") + b2 = b.replace(suf, "") + if f"{a2}/USDT" == b or f"{a2}/USDT:USDT" == b: + return True + if f"{b2}/USDT" == a or f"{b2}/USDT:USDT" == a: + return True + if a2 == b2: + return True + return False + + +def _order_type_str(order: dict) -> str: + info = order.get("info") or {} + if isinstance(info, dict): + for key in ("orderType", "type", "origType", "algoType", "ordType"): + val = info.get(key) + if val: + return str(val).upper() + return str(order.get("type") or "").upper() + + +def _is_conditional_type(typ: str) -> bool: + t = (typ or "").upper() + if not t: + return False + keys = ("STOP", "TAKE_PROFIT", "TRAIL", "TRIGGER", "CONDITIONAL", "OCO") + return any(k in t for k in keys) + + +def _order_label(typ: str, side: str, reduce_only: bool | None) -> str: + t = (typ or "").upper() + side_l = (side or "").lower() + parts = [] + if "TAKE_PROFIT" in t: + parts.append("止盈") + elif "STOP" in t: + parts.append("止损") + elif "LIMIT" in t: + parts.append("限价") + elif "MARKET" in t: + parts.append("市价") + else: + parts.append(typ or "委托") + if side_l == "buy": + parts.append("买入") + elif side_l == "sell": + parts.append("卖出") + if reduce_only: + parts.append("·只减仓") + return " ".join(parts) + + +def _normalize_raw_order(order: dict, *, channel: str) -> dict[str, Any] | None: + if not isinstance(order, dict): + return None + info = order.get("info") or {} + if not isinstance(info, dict): + info = {} + oid = order.get("id") or info.get("algoId") or info.get("orderId") or info.get("ordId") + if oid is None: + return None + sym = str(order.get("symbol") or info.get("symbol") or info.get("instId") or "") + typ = _order_type_str(order) + side = str(order.get("side") or info.get("side") or "").lower() + reduce_only = order.get("reduceOnly") + if reduce_only is None: + reduce_only = info.get("reduceOnly") + try: + reduce_only = bool(reduce_only) if reduce_only is not None else None + except (TypeError, ValueError): + reduce_only = None + sl_trig = _coerce_float(info.get("slTriggerPx"), order.get("stopLossPrice")) + tp_trig = _coerce_float(info.get("tpTriggerPx"), order.get("takeProfitPrice")) + trig = _coerce_float( + order.get("stopPrice"), + order.get("triggerPrice"), + info.get("triggerPrice"), + info.get("stopPrice"), + info.get("triggerPx"), + sl_trig, + tp_trig, + ) + price = _coerce_float(order.get("price"), info.get("price"), info.get("ordPx")) + amt = _coerce_float(order.get("amount"), order.get("remaining"), info.get("quantity"), info.get("origQty"), info.get("sz")) + category = "conditional" if _is_conditional_type(typ) or channel == "algo" else "limit" + label = _order_label(typ, side, reduce_only) + if sl_trig is not None and tp_trig is not None: + label = f"止盈止损 SL={sl_trig:g} TP={tp_trig:g}" + elif sl_trig is not None: + label = f"止损 {sl_trig:g}" + elif tp_trig is not None: + label = f"止盈 {tp_trig:g}" + return { + "id": str(oid), + "symbol": sym, + "channel": channel, + "category": category, + "label": label, + "type": typ, + "side": side, + "amount": amt, + "trigger_price": trig, + "price": price, + "reduce_only": reduce_only, + "status": str(order.get("status") or info.get("status") or "open"), + } + + +def _okx_normalize_orders(raw: dict, channel: str) -> list[dict[str, Any]]: + """OKX 算法单常一笔同时含 SL+TP,拆成两条供中控「交易所止盈止损」展示。""" + n = _normalize_raw_order(dict(raw), channel=channel) + if not n: + return [] + info = raw.get("info") or {} + if not isinstance(info, dict): + info = {} + sl_trig = _coerce_float(info.get("slTriggerPx"), raw.get("stopLossPrice")) + tp_trig = _coerce_float(info.get("tpTriggerPx"), raw.get("takeProfitPrice")) + if sl_trig is None or tp_trig is None or sl_trig == tp_trig: + return [n] + base_id = n["id"] + rows: list[dict[str, Any]] = [] + for role, px, lbl in ( + ("sl", sl_trig, f"止损 {sl_trig:g}"), + ("tp", tp_trig, f"止盈 {tp_trig:g}"), + ): + row = dict(n) + row["id"] = f"{base_id}:{role}" + row["algo_id"] = base_id + row["label"] = lbl + row["trigger_price"] = px + row["category"] = "conditional" + row["channel"] = channel + rows.append(row) + return rows + + +def _okx_algo_order_id(order_id: str) -> str: + oid = str(order_id or "") + if ":" in oid: + return oid.split(":", 1)[0] + return oid + + +def _binance_list(ex: Any, symbol: str | None) -> list[dict]: + ex.load_markets() + out: list[dict] = [] + symbols: list[str] = [] + if symbol: + try: + symbols = [ex.market(symbol)["symbol"]] + except Exception: + symbols = [symbol] + else: + symbols = [] + try: + for p in ex.fetch_positions() or []: + sym = p.get("symbol") + if sym: + symbols.append(sym) + except Exception: + pass + if symbol and not symbols: + symbols = [symbol] + + def collect(ex_sym: str) -> None: + market = ex.market(ex_sym) + contract_id = market.get("id") + try: + for o in ex.fetch_open_orders(ex_sym) or []: + item = dict(o) + item["_channel"] = "regular" + n = _normalize_raw_order(item, channel="regular") + if n: + out.append(n) + except Exception: + pass + try: + if contract_id and hasattr(ex, "fapiPrivateGetOpenAlgoOrders"): + raw = ex.fapiPrivateGetOpenAlgoOrders({"symbol": contract_id}) + items = raw if isinstance(raw, list) else (raw.get("orders") or raw.get("data") or []) + for info in items or []: + if not isinstance(info, dict): + continue + wrapped = { + "id": info.get("algoId") or info.get("orderId"), + "symbol": ex_sym, + "info": info, + "type": info.get("orderType") or info.get("type"), + "side": (info.get("side") or "").lower(), + "amount": info.get("quantity") or info.get("origQty"), + "stopPrice": info.get("triggerPrice") or info.get("stopPrice"), + "reduceOnly": info.get("reduceOnly"), + } + n = _normalize_raw_order(wrapped, channel="algo") + if n: + out.append(n) + except Exception: + pass + + if symbols: + seen = set() + for s in symbols: + if s in seen: + continue + seen.add(s) + collect(s) + return out + + +def _okx_list(ex: Any, symbol: str | None) -> list[dict]: + ex.load_markets() + out: list[dict] = [] + symbols: list[str] = [] + if symbol: + try: + symbols = [ex.market(symbol)["symbol"]] + except Exception: + symbols = [symbol] + else: + try: + for p in ex.fetch_positions() or []: + sym = p.get("symbol") + if sym: + symbols.append(sym) + except Exception: + pass + if symbol and not symbols: + symbols = [symbol] + seen: set[tuple[str, str]] = set() + for sym in symbols: + try: + for o in fetch_okx_all_open_orders(ex, sym): + ch = "algo" if _is_conditional_type(_order_type_str(o)) else "regular" + for n in _okx_normalize_orders(dict(o), channel=ch): + key = (n["id"], n.get("channel") or ch) + if key in seen: + continue + seen.add(key) + out.append(n) + except Exception: + pass + return out + + +def _gate_extract_trigger_rule(info: dict) -> int | None: + if not isinstance(info, dict): + return None + trig = info.get("trigger") + if isinstance(trig, dict) and trig.get("rule") is not None: + try: + return int(trig["rule"]) + except (TypeError, ValueError): + pass + try: + return int(info.get("rule")) + except (TypeError, ValueError): + return None + + +def _gate_tpsl_role_from_rule(rule: int | None, direction: str) -> str | None: + if rule is None: + return None + d = (direction or "long").strip().lower() + if d == "long": + return "sl" if rule == 2 else ("tp" if rule == 1 else None) + return "sl" if rule == 1 else ("tp" if rule == 2 else None) + + +def _gate_trigger_params(ex: Any) -> dict: + p = {"type": "swap", "trigger": True} + try: + ex.load_unified_status() + if ex.options.get("unifiedAccount"): + p["unifiedAccount"] = True + except Exception: + pass + return p + + +def _gate_list(ex: Any, symbol: str | None) -> list[dict]: + ex.load_markets() + out: list[dict] = [] + symbols: list[str] = [] + if symbol: + try: + symbols = [ex.market(symbol)["symbol"]] + except Exception: + symbols = [symbol] + else: + try: + for p in ex.fetch_positions() or []: + sym = p.get("symbol") + if sym: + symbols.append(sym) + except Exception: + pass + if symbol and not symbols: + symbols = [symbol] + trig_params = _gate_trigger_params(ex) + seen = set() + for sym in symbols: + if sym in seen: + continue + seen.add(sym) + try: + for o in ex.fetch_open_orders(sym) or []: + n = _normalize_raw_order(dict(o), channel="regular") + if n: + out.append(n) + except Exception: + pass + try: + for o in ex.fetch_open_orders(sym, params=trig_params) or []: + item = dict(o) + item["type"] = item.get("type") or "trigger" + n = _normalize_raw_order(item, channel="algo") + if n: + info = o.get("info") if isinstance(o.get("info"), dict) else {} + rule = _gate_extract_trigger_rule(info) + if rule is not None: + n["gate_trigger_rule"] = rule + out.append(n) + except Exception: + pass + return out + + +def list_open_orders(ex: Any, exchange_kind: str, symbol: str | None = None) -> list[dict]: + kind = (exchange_kind or "binance").lower() + if kind == "binance": + orders = _binance_list(ex, symbol) + elif kind == "okx": + orders = _okx_list(ex, symbol) + else: + orders = _gate_list(ex, symbol) + if symbol: + orders = [o for o in orders if symbols_match(symbol, o.get("symbol") or "")] + # 去重 id+channel + seen: set[tuple[str, str]] = set() + uniq: list[dict] = [] + for o in orders: + key = (o["id"], o["channel"]) + if key in seen: + continue + seen.add(key) + uniq.append(o) + return uniq + + +def _enrich_gate_conditional_labels(cond: list[dict], side: str) -> None: + """Gate 仓位类触发单在 ccxt 中常显示为「市价·只减仓」,按 trigger.rule 标为止盈/止损。""" + direction = (side or "long").strip().lower() + for o in cond: + if not isinstance(o, dict): + continue + if (o.get("label") or "").startswith(("止盈", "止损")): + continue + role = _gate_tpsl_role_from_rule(o.get("gate_trigger_rule"), direction) + trig = o.get("trigger_price") + if not role or trig is None: + continue + try: + trig_f = float(trig) + except (TypeError, ValueError): + continue + prefix = "止损" if role == "sl" else "止盈" + o["label"] = f"{prefix} {trig_f:g}" + + +def attach_orders_to_positions(positions: list[dict], orders: list[dict]) -> None: + for p in positions: + sym = p.get("symbol") or "" + matched = [o for o in orders if symbols_match(sym, o.get("symbol") or "")] + cond = [o for o in matched if o.get("category") == "conditional"] + _enrich_gate_conditional_labels(cond, p.get("side") or "long") + p["conditional_orders"] = cond + p["regular_orders"] = [o for o in matched if o.get("category") != "conditional"] + + +def cancel_order( + ex: Any, + exchange_kind: str, + symbol: str, + order_id: str, + channel: str = "regular", +) -> None: + kind = (exchange_kind or "binance").lower() + ex.load_markets() + market = ex.market(symbol) + unified = market["symbol"] + ch = (channel or "regular").lower() + if kind == "binance" and ch == "algo": + contract_id = market.get("id") + if contract_id and hasattr(ex, "fapiPrivateDeleteAlgoOrder"): + ex.fapiPrivateDeleteAlgoOrder({"symbol": contract_id, "algoId": str(order_id)}) + return + params = None + if kind == "gate" and ch == "algo": + params = _gate_trigger_params(ex) + elif kind == "okx" and ch == "algo": + params = {"stop": True} + oid = _okx_algo_order_id(order_id) if kind == "okx" else str(order_id) + ex.cancel_order(oid, unified, params) + + +def cancel_orders_for_symbol( + ex: Any, + exchange_kind: str, + symbol: str, + *, + scope: str = "all", +) -> int: + """scope: all | conditional | limit""" + orders = list_open_orders(ex, exchange_kind, symbol) + if scope == "conditional": + orders = [o for o in orders if o.get("category") == "conditional"] + elif scope == "limit": + orders = [o for o in orders if o.get("category") != "conditional"] + n = 0 + for o in orders: + try: + cancel_order(ex, exchange_kind, symbol, o["id"], o.get("channel") or "regular") + n += 1 + except Exception: + pass + return n + + +def _binance_cancel_algo_open(ex: Any, symbol: str) -> None: + try: + market = ex.market(symbol) + cid = market.get("id") + if cid and hasattr(ex, "fapiPrivateDeleteAlgoOpenOrders"): + ex.fapiPrivateDeleteAlgoOpenOrders({"symbol": cid}) + except Exception: + pass + + +def _binance_trigger_params() -> dict[str, Any]: + wt = (os.getenv("BINANCE_TRIGGER_WORKING_TYPE") or "CONTRACT_PRICE").strip().upper() + if wt not in ("CONTRACT_PRICE", "MARK_PRICE"): + wt = "CONTRACT_PRICE" + return {"workingType": wt} + + +def _binance_place_tp_sl( + ex: Any, + symbol: str, + direction: str, + amount: float, + stop_loss: float, + take_profit: float, + *, + position_mode: str = "hedge", +) -> None: + ex.load_markets() + market = ex.market(symbol) + if not market.get("swap"): + raise RuntimeError("仅支持永续合约") + close_side = "sell" if direction == "long" else "buy" + amt = float(ex.amount_to_precision(symbol, float(amount))) + if amt <= 0: + raise RuntimeError("止盈止损:可平数量经精度舍入后为 0") + sl_px = ex.price_to_precision(symbol, float(stop_loss)) + tp_px = ex.price_to_precision(symbol, float(take_profit)) + common = dict(_binance_trigger_params()) + if (position_mode or "hedge").lower() in ("hedge", "dual", "double", "hedged"): + common["positionSide"] = "LONG" if direction == "long" else "SHORT" + last_err: Exception | None = None + for attempt in range(6): + try: + ex.create_order( + symbol, "STOP_MARKET", close_side, amt, None, dict(common, stopPrice=sl_px) + ) + time.sleep(0.05) + ex.create_order( + symbol, + "TAKE_PROFIT_MARKET", + close_side, + amt, + None, + dict(common, stopPrice=tp_px), + ) + return + except Exception as e: + last_err = e + cancel_orders_for_symbol(ex, "binance", symbol, scope="conditional") + _binance_cancel_algo_open(ex, symbol) + time.sleep(0.2 * (attempt + 1)) + raise RuntimeError(f"Binance 未接受止盈/止损:{last_err}") + + +def _okx_order_params( + direction: str, + *, + reduce_only: bool, + pos_mode: str, + td_mode: str, + for_algo_tpsl: bool = False, +) -> dict: + params: dict[str, Any] = {"tdMode": td_mode or "cross"} + if (pos_mode or "hedge").lower() in ("hedge", "long_short_mode", "dual"): + ps = "long" if direction == "long" else "short" + params["posSide"] = ps + params["positionSide"] = ps + # OKX 条件/OCO 算法单勿带 reduceOnly,否则可能被当市价减仓立即成交 + if reduce_only and not for_algo_tpsl: + params["reduceOnly"] = True + return params + + +def _okx_place_tp_sl( + ex: Any, + symbol: str, + direction: str, + amount: float, + stop_loss: float, + take_profit: float, + *, + pos_mode: str = "hedge", + td_mode: str = "cross", +) -> None: + """OKX 永续:一笔 OCO 算法单挂止盈+止损(勿 reduceOnly + 分两笔 market)。""" + ex.load_markets() + close_side = "sell" if direction == "long" else "buy" + amt = float(ex.amount_to_precision(symbol, float(amount))) + if amt <= 0: + raise RuntimeError("止盈止损:可平数量经精度舍入后为 0") + base = _okx_order_params( + direction, + reduce_only=False, + pos_mode=pos_mode, + td_mode=td_mode, + for_algo_tpsl=True, + ) + sl_px = ex.price_to_precision(symbol, float(stop_loss)) + tp_px = ex.price_to_precision(symbol, float(take_profit)) + order_params = { + **base, + "stopLossPrice": float(sl_px), + "takeProfitPrice": float(tp_px), + "tpOrdPx": "-1", + "slOrdPx": "-1", + } + last_err: Exception | None = None + for attempt in range(6): + try: + ex.create_order(symbol, "oco", close_side, amt, None, order_params) + return + except Exception as e: + last_err = e + cancel_orders_for_symbol(ex, "okx", symbol, scope="conditional") + time.sleep(0.2 * (attempt + 1)) + raise RuntimeError(f"OKX 未接受止盈/止损条件单:{last_err}") + + +def _gate_tpsl_env() -> tuple[bool, int, int, str]: + use_pos = (os.getenv("GATE_TPSL_USE_POSITION_ORDER") or "true").lower() in ("1", "true", "yes") + exp = int(os.getenv("GATE_TPSL_TRIGGER_EXPIRATION", str(7 * 86400))) + pt = int(os.getenv("GATE_TPSL_PRICE_TYPE", "0")) + if pt < 0 or pt > 2: + pt = 0 + pos_mode = (os.getenv("GATE_POS_MODE") or "hedge").strip().lower() + return use_pos, exp, pt, pos_mode + + +def _gate_place_tp_sl_position( + ex: Any, + symbol: str, + direction: str, + stop_loss: float, + take_profit: float, + *, + pos_mode: str, + price_type: int, + expiration: int, +) -> None: + ex.load_markets() + market = ex.market(symbol) + if not market.get("swap"): + raise RuntimeError("仅支持永续合约") + settle = market["settleId"] + contract = market["id"] + order_type = "close-long-position" if direction == "long" else "close-short-position" + close_side = "sell" if direction == "long" else "buy" + sl_rule, tp_rule = (2, 1) if close_side == "sell" else (1, 2) + initial: dict[str, Any] = { + "contract": contract, + "size": 0, + "price": "0", + "close": True, + "reduce_only": True, + "tif": "ioc", + "text": "api", + } + if pos_mode in ("hedge", "dual", "double"): + initial["auto_size"] = "close_long" if direction == "long" else "close_short" + # Gate API 1018:auto_size=close_long|close_short 时 initial.close 须为 false + initial["close"] = False + sl_s = ex.price_to_precision(symbol, float(stop_loss)) + tp_s = ex.price_to_precision(symbol, float(take_profit)) + + def _payload(trigger_price: str, rule: int) -> dict: + trig: dict[str, Any] = { + "strategy_type": 0, + "price_type": price_type, + "price": trigger_price, + "rule": rule, + } + if expiration > 0: + trig["expiration"] = expiration + return { + "settle": settle, + "initial": dict(initial), + "trigger": trig, + "order_type": order_type, + } + + last_err: Exception | None = None + for attempt in range(6): + try: + ex.privateFuturesPostSettlePriceOrders(_payload(sl_s, sl_rule)) + try: + ex.privateFuturesPostSettlePriceOrders(_payload(tp_s, tp_rule)) + except Exception: + cancel_orders_for_symbol(ex, "gate", symbol, scope="conditional") + raise + return + except Exception as e: + last_err = e + time.sleep(0.2 * (attempt + 1)) + raise RuntimeError(f"Gate 仓位类止盈/止损未接受:{last_err}") + + +def _gate_place_tp_sl_legacy( + ex: Any, + symbol: str, + direction: str, + amount: float, + stop_loss: float, + take_profit: float, +) -> None: + ex.load_markets() + close_side = "sell" if direction == "long" else "buy" + base = {"reduceOnly": True} + last_err: Exception | None = None + for attempt in range(6): + try: + ex.create_order( + symbol, + "market", + close_side, + amount, + None, + dict(base, stopLossPrice=float(stop_loss)), + ) + ex.create_order( + symbol, + "market", + close_side, + amount, + None, + dict(base, takeProfitPrice=float(take_profit)), + ) + return + except Exception as e: + last_err = e + time.sleep(0.2 * (attempt + 1)) + raise RuntimeError(f"Gate 条件止盈/止损未接受:{last_err}") + + +def _gate_td_mode_cross() -> bool: + td = (os.getenv("GATE_TD_MODE") or "cross").strip().lower() + return td in ("cross", "cross_margin") + + +def _gate_last_price(ex: Any, symbol: str) -> float | None: + ex.load_markets() + unified = ex.market(symbol)["symbol"] + try: + t = ex.fetch_ticker(unified) + except Exception: + return None + if not isinstance(t, dict): + return None + info = t.get("info") if isinstance(t.get("info"), dict) else {} + for key in ("last", "mark", "close", "index_price"): + v = t.get(key) if key in t else info.get(key) + try: + f = float(v) + if f > 0: + return f + except (TypeError, ValueError): + continue + return None + + +def _gate_clamp_tpsl_prices( + ex: Any, + symbol: str, + direction: str, + stop_loss: float, + take_profit: float, +) -> tuple[float, float]: + """ + Gate price_orders:空仓止损/多仓止盈 trigger>last;空仓止盈/多仓止损 trigger= last: + tp = float(ex.price_to_precision(unified, last * (1 - gap))) + else: + if sl >= last: + sl = float(ex.price_to_precision(unified, last * (1 - gap))) + if tp <= last: + tp = float(ex.price_to_precision(unified, last * (1 + gap))) + return sl, tp + + +def _gate_place_tp_sl( + ex: Any, + symbol: str, + direction: str, + amount: float, + stop_loss: float, + take_profit: float, +) -> None: + use_pos, exp, pt, pos_mode = _gate_tpsl_env() + pos_err: Exception | None = None + if use_pos: + try: + _gate_place_tp_sl_position( + ex, symbol, direction, stop_loss, take_profit, + pos_mode=pos_mode, price_type=pt, expiration=exp, + ) + return + except Exception as e: + pos_err = e + if _gate_td_mode_cross(): + raise RuntimeError( + f"Gate 仓位类止盈/止损未接受(全仓不支持 ccxt 条件单回退):{pos_err}" + ) from e + try: + _gate_place_tp_sl_legacy(ex, symbol, direction, amount, stop_loss, take_profit) + except Exception as legacy_err: + if pos_err is not None: + raise RuntimeError( + f"Gate 仓位类止盈/止损未接受:{pos_err};条件单回退亦失败:{legacy_err}" + ) from legacy_err + raise + + +def replace_position_tpsl( + ex: Any, + exchange_kind: str, + symbol: str, + direction: str, + amount: float, + stop_loss: float, + take_profit: float, +) -> dict[str, Any]: + """ + 先撤销该合约全部条件单,再挂止盈+止损。与四实例策略页逻辑对齐(读各目录 .env 中 GATE_/BINANCE_/OKX_ 参数)。 + """ + kind = (exchange_kind or "binance").lower() + direction = (direction or "long").strip().lower() + if direction not in ("long", "short"): + raise ValueError("direction 须为 long 或 short") + sl = float(stop_loss) + tp = float(take_profit) + if sl <= 0 or tp <= 0: + raise ValueError("止损、止盈价格须大于 0") + ex.load_markets() + cancelled = cancel_orders_for_symbol(ex, kind, symbol, scope="conditional") + if kind == "binance": + _binance_cancel_algo_open(ex, symbol) + time.sleep(0.08) + amt = float(amount) + if amt <= 0: + raise ValueError("持仓数量无效") + if kind == "binance": + pm = (os.getenv("BINANCE_POSITION_MODE") or "hedge").strip().lower() + _binance_place_tp_sl(ex, symbol, direction, amt, sl, tp, position_mode=pm) + elif kind == "okx": + pm = (os.getenv("OKX_POS_MODE") or "hedge").strip().lower() + td = (os.getenv("OKX_TD_MODE") or "cross").strip() + _okx_place_tp_sl(ex, symbol, direction, amt, sl, tp, pos_mode=pm, td_mode=td) + else: + sl, tp = _gate_clamp_tpsl_prices(ex, symbol, direction, sl, tp) + _gate_place_tp_sl(ex, symbol, direction, amt, sl, tp) + return { + "symbol": symbol, + "direction": direction, + "amount": amt, + "stop_loss": sl, + "take_profit": tp, + "cancelled_conditional": cancelled, + } diff --git a/manual_trading_hub/hub.py b/manual_trading_hub/hub.py index 9f701da..95eedb7 100644 --- a/manual_trading_hub/hub.py +++ b/manual_trading_hub/hub.py @@ -1,2776 +1,2776 @@ -""" -多账户交易中控:监控区 / 系统设置。 -聚合各实例监控数据与子代理 /status;下单请在各实例网页操作。 -""" -from __future__ import annotations - -import asyncio -import os -import sys -from contextlib import asynccontextmanager -from pathlib import Path - -_REPO_ROOT = Path(__file__).resolve().parent.parent -if str(_REPO_ROOT) not in sys.path: - sys.path.insert(0, str(_REPO_ROOT)) - -from hub_kline_store import format_ohlcv_detail, resolve_chart_bars, retention_days -from hub_ohlcv_lib import ( - CHART_TIMEFRAME_ORDER, - CHART_TIMEFRAMES, - bar_limit_for_timeframe, - chart_chunk_limit, - chart_initial_limit, - chart_memory_cap, - retention_policy_meta, -) -from hub_volume_rank_lib import ( - TOP_N_DEFAULT, - _exchange_rank_row_stale, - cache_needs_refresh, - format_volume_quote, - get_cached_rank, - load_volume_rank_cache, - merge_exchange_rank, - rank_date_label, - save_volume_rank_cache, - seconds_until_next_reset, - volume_rank_reset_hour, -) -from hub_symbol_archive_lib import ( - ARCHIVE_DEFAULT_TIMEFRAME, - ARCHIVE_QUOTES_MAX, - ARCHIVE_SEED_LOOKBACK_DAYS, - ARCHIVE_SYNC_INTERVAL_SEC, - ARCHIVE_TIMEFRAMES, - ARCHIVE_TRADE_DAYS, - ARCHIVE_TRADE_LIMIT, - ARCHIVE_VISIBLE_BARS_DEFAULT, - create_review_quote, - delete_review_quote, - init_db as init_archive_db, - list_daily_trades, - list_archive_calendar, - list_review_quotes, - list_symbol_rows, - load_symbol_trades, - parse_wall_clock_ms, - resolve_archive_chart, - sync_exchange_symbol_archives, - today_trading_day, - update_review_quote, - upsert_trade_overlay, -) -from hub_entry_plan_lib import ( - compute_entry_plan_stats, - create_entry_plan, - delete_entry_plan, - get_entry_plan, - init_db as init_entry_plan_db, - list_entry_plans, - meta_payload as entry_plan_meta_payload, - update_entry_plan, -) -from hub_macro_calendar_lib import ( - MACRO_EVENT_LABELS, - MACRO_EVENT_TYPES, - create_event as create_macro_event, - delete_event as delete_macro_event, - init_db as init_macro_calendar_db, - list_active_alerts, - list_events as list_macro_events, - update_event as update_macro_event, -) -from env_load import load_hub_dotenv - -load_hub_dotenv() - -import httpx -from fastapi import Body, FastAPI, HTTPException, Request -from fastapi.responses import FileResponse, JSONResponse -from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel, Field - -from settings_store import ( - enabled_exchanges, - env_force_disabled_ids, - load_settings, - normalize_display_prefs, - normalize_supervisor_settings, - save_settings, -) -from hub_web_auth import ( - SESSION_COOKIE, - SESSION_MAX_AGE_SEC, - clear_session_cookie, - cookie_secure_for_request, - create_session_token, - embed_allowed, - embed_frame_ancestors, - is_public_path, - password_required, - set_session_cookie, - validate_session_token, - expected_username, - verify_credentials, -) -from hub_sso import HUB_SSO_TTL_SEC, mint_hub_sso_token, safe_next_path -from url_public import browser_url, default_review_url, public_origin -from urllib.parse import urlencode - -from hub_board_cache import HUB_BOARD_POLL_INTERVAL, board_store -from hub_dashboard_cache import dashboard_store -from hub_dashboard import DASHBOARD_POLL_INTERVAL_SEC -from hub_supervisor_cache import supervisor_store -from hub_supervisor_lib import process_supervisor_tick, set_supervisor_notify_hook -from hub_ai.supervisor import make_supervisor_ai_reply_fn -from hub_ai.config import trading_day_reset_hour -from hub_chart_cache import ( - HUB_CHART_POLL_INTERVAL, - HUB_CHART_WATCH_TTL_SEC, - chart_poll_store, - parse_series_key, -) - -try: - from exchange_orders import symbols_match as _symbols_match -except ImportError: - - def _symbols_match(position_symbol: str, order_symbol: str) -> bool: - a = (position_symbol or "").strip().upper() - b = (order_symbol or "").strip().upper() - return bool(a and b and a == b) - -HUB_HOST = os.getenv("HUB_HOST", "0.0.0.0") -HUB_PORT = int(os.getenv("HUB_PORT", "5100")) -HUB_BRIDGE_TOKEN = (os.getenv("HUB_BRIDGE_TOKEN") or os.getenv("CONTROL_TOKEN") or "").strip() -_trust_raw = (os.getenv("HUB_TRUST_LAN", "true") or "").strip().lower() -HUB_TRUST_LAN = _trust_raw not in ("0", "false", "no", "off") -_allow_pub_raw = (os.getenv("HUB_ALLOW_PUBLIC") or "").strip().lower() -# 云服务器 + 域名反代时设为 true:不做 IP 限制,仅靠 HUB_PASSWORD / 登录页保护 -HUB_ALLOW_PUBLIC = _allow_pub_raw in ("1", "true", "yes", "on") -DIR = Path(__file__).resolve().parent -HUB_BUILD = "20260607-hub-archive" -_archive_sync_stop: asyncio.Event | None = None -_archive_sync_task: asyncio.Task | None = None -_last_archive_sync: dict | None = None -_volume_rank_stop: asyncio.Event | None = None -_volume_rank_task: asyncio.Task | None = None -_volume_rank_cache: dict | None = None -HUB_AGENT_TIMEOUT = float(os.getenv("HUB_AGENT_TIMEOUT", "8")) -HUB_FLASK_TIMEOUT = float(os.getenv("HUB_FLASK_TIMEOUT", "10")) -HUB_BOARD_TIMEOUT = float(os.getenv("HUB_BOARD_TIMEOUT", "45")) -_board_key_prices_raw = (os.getenv("HUB_BOARD_KEY_PRICES", "true") or "").strip().lower() -HUB_BOARD_KEY_PRICES = _board_key_prices_raw in ("1", "true", "yes", "on") - - -def _is_local(host: str | None) -> bool: - if not host: - return False - h = host.lower() - return h in ("127.0.0.1", "::1", "localhost") or h.startswith("::ffff:127.0.0.1") - - -def _ipv4_rfc1918_private(host: str) -> bool: - h = host.lower() - if h.startswith("::ffff:"): - h = h[7:] - parts = h.split(".") - if len(parts) != 4: - return False - try: - a, b, c, d = (int(x) for x in parts) - except ValueError: - return False - if any(x < 0 or x > 255 for x in (a, b, c, d)): - return False - if a == 10: - return True - if a == 172 and 16 <= b <= 31: - return True - if a == 192 and b == 168: - return True - return False - - -def _client_allowed(host: str | None) -> bool: - if _is_local(host): - return True - if HUB_TRUST_LAN and host and _ipv4_rfc1918_private(host): - return True - return False - - -def _hub_headers() -> dict[str, str]: - if not HUB_BRIDGE_TOKEN: - return {} - return {"X-Hub-Token": HUB_BRIDGE_TOKEN} - - -def _agent_headers() -> dict[str, str]: - if not HUB_BRIDGE_TOKEN: - return {} - return {"X-Control-Token": HUB_BRIDGE_TOKEN} - - -def _find_exchange(ex_id: str) -> dict | None: - for ex in load_settings().get("exchanges") or []: - if str(ex.get("id")) == str(ex_id): - return ex - return None - - -async def _run_chart_poll() -> dict: - keys = chart_poll_store.active_series_keys() - if not keys: - return {"ok": True, "series_count": 0, "polled": 0} - polled = 0 - errors: list[str] = [] - for key in keys: - parsed = parse_series_key(key) - if not parsed: - continue - ex_k, sym, tf = parsed - ex = _find_exchange_by_key(ex_k) - if not ex or not ex.get("enabled"): - continue - - ex_ref = ex - sym_ref = sym - tf_ref = tf - - def remote_fetch(**kwargs) -> dict: - tf_use = kwargs.get("timeframe") or tf_ref - return _fetch_instance_ohlcv_sync( - ex_ref, - symbol=kwargs.get("symbol") or sym_ref, - timeframe=tf_use, - since_ms=kwargs.get("since_ms"), - limit=int(kwargs.get("limit") or bar_limit_for_timeframe(tf_use)), - ) - - try: - result = await asyncio.to_thread( - resolve_chart_bars, - ex_k, - sym, - tf, - remote_fetch, - force_refresh=False, - tail_refresh=True, - ) - polled += 1 - chart_poll_store.note_series_result( - ex_k, - sym, - tf, - ok=bool(result.get("ok")), - fetched=int(result.get("fetched") or 0), - error=None if result.get("ok") else str(result.get("msg") or "poll_failed"), - candles=result.get("candles") if result.get("ok") else None, - price_tick=result.get("price_tick"), - ) - if not result.get("ok"): - errors.append(f"{key}:{result.get('msg')}") - except Exception as e: - chart_poll_store.note_series_result(ex_k, sym, tf, ok=False, error=str(e)) - errors.append(f"{key}:{e}") - out: dict = {"ok": True, "series_count": len(keys), "polled": polled} - if errors: - out["errors"] = errors[:8] - return out - - -async def _run_board_aggregate() -> dict: - try: - body = await asyncio.wait_for(_build_monitor_board_payload(), timeout=HUB_BOARD_TIMEOUT) - try: - from hub_fund_history_lib import record_fund_snapshot_from_board - - await asyncio.to_thread(record_fund_snapshot_from_board, body.get("rows") or []) - except Exception: - pass - return {"ok": True, **body} - except asyncio.TimeoutError: - return { - "ok": False, - "rows": [], - "error": "board_timeout", - "msg": ( - f"监控聚合超过 {int(HUB_BOARD_TIMEOUT)} 秒。" - "请检查子代理/Flask,或设 HUB_BOARD_KEY_PRICES=false、缩短 HUB_FLASK_TIMEOUT" - ), - "updated_at": __import__("datetime").datetime.now().isoformat(timespec="seconds"), - } - - -def _schedule_board_refresh() -> None: - board_store.request_refresh() - dashboard_store.request_refresh() - supervisor_store.request_refresh() - - -async def _run_archive_sync_once() -> dict: - global _last_archive_sync - init_archive_db() - settings = load_settings() - targets = enabled_exchanges(settings) - results: list[dict] = [] - for ex in targets: - ex_key = str(ex.get("key") or "").strip().lower() - if not ex_key: - continue - trades_resp = await asyncio.to_thread( - _fetch_instance_trades_archive_sync, - ex, - days=ARCHIVE_TRADE_DAYS, - limit=ARCHIVE_TRADE_LIMIT, - ) - if not trades_resp.get("ok"): - st = trades_resp.get("status") - msg = ( - trades_resp.get("msg") - or trades_resp.get("error") - or trades_resp.get("detail") - or "拉取交易失败" - ) - if st == 404: - msg = ( - "HTTP 404:该 Flask 未注册 /api/hub/trades/archive。" - "请在仓库根目录 git pull 后 pm2 restart crypto_gate crypto_gate_bot" - ) - results.append( - { - "exchange_key": ex_key, - "name": ex.get("name"), - "ok": False, - "status": st, - "msg": msg, - } - ) - continue - trades = trades_resp.get("trades") or [] - for t in trades: - if isinstance(t, dict): - t["exchange_key"] = ex_key - - def remote_fetch(**kwargs): - return _fetch_instance_ohlcv_sync( - ex, - symbol=kwargs.get("symbol") or "", - timeframe=kwargs.get("timeframe") or "5m", - since_ms=kwargs.get("since_ms"), - limit=int(kwargs.get("limit") or 500), - ) - - r = await asyncio.to_thread( - sync_exchange_symbol_archives, - ex_key, - trades, - remote_fetch, - ) - r["name"] = ex.get("name") - r["trade_count"] = len(trades) - results.append(r) - out = { - "ok": True, - "exchanges": len(targets), - "results": results, - "updated_at": __import__("datetime").datetime.now().isoformat(timespec="seconds"), - } - _last_archive_sync = out - return out - - -def _fetch_instance_volume_rank_sync(ex: dict, *, top_n: int = TOP_N_DEFAULT) -> dict: - base = (ex.get("flask_url") or "").rstrip("/") - if not base: - return {"ok": False, "msg": "未配置 flask_url"} - params = {"top": str(int(top_n))} - url = f"{base}/api/hub/volume-rank?{urlencode(params)}" - try: - with httpx.Client(timeout=max(HUB_FLASK_TIMEOUT, 120.0)) as client: - r = client.get(url, headers=_hub_headers()) - if r.status_code >= 400: - parsed = _parse_http_json_body(r) - parsed.setdefault("ok", False) - parsed.setdefault("status", r.status_code) - return parsed - data = r.json() if r.content else {} - return data if isinstance(data, dict) else {"ok": False, "msg": "无效 JSON"} - except Exception as e: - return {"ok": False, "msg": str(e)} - - -def _get_volume_rank_cache() -> dict: - global _volume_rank_cache - if _volume_rank_cache is None: - _volume_rank_cache = load_volume_rank_cache() - return _volume_rank_cache - - -def _refresh_volume_ranks(*, force: bool = False) -> dict: - global _volume_rank_cache - expected = rank_date_label() - cache = _get_volume_rank_cache() - targets = enabled_exchanges(load_settings()) - required_keys = [ - str(ex.get("key") or "").strip().lower() - for ex in targets - if ex.get("enabled") and str(ex.get("key") or "").strip() - ] - if not force and not cache_needs_refresh( - cache, expected_rank_date=expected, required_keys=required_keys - ): - return { - "ok": True, - "skipped": True, - "rank_date": cache.get("rank_date"), - "updated_at": cache.get("updated_at"), - } - errors: list[str] = [] - for ex in targets: - ex_key = str(ex.get("key") or "").strip().lower() - if not ex_key or not ex.get("enabled"): - continue - resp = _fetch_instance_volume_rank_sync(ex, top_n=TOP_N_DEFAULT) - if resp.get("ok") and resp.get("items"): - cache = merge_exchange_rank(cache, ex_key, resp) - else: - msg = str(resp.get("msg") or resp.get("error") or "拉取失败") - if resp.get("ok") and not resp.get("items"): - msg = msg if msg != "拉取失败" else "无有效成交额数据" - errors.append(f"{ex_key}:{msg}") - exchanges = dict(cache.get("exchanges") or {}) - prev = dict(exchanges.get(ex_key) or {}) - prev["error"] = msg - if not prev.get("items"): - prev["items"] = [] - exchanges[ex_key] = prev - cache["exchanges"] = exchanges - cache["rank_date"] = expected - save_volume_rank_cache(cache) - _volume_rank_cache = cache - out: dict = { - "ok": True, - "rank_date": expected, - "exchanges": len(targets), - "updated_at": cache.get("updated_at"), - } - if errors: - out["errors"] = errors[:8] - return out - - -async def _volume_rank_loop() -> None: - global _volume_rank_stop - stop = _volume_rank_stop - if stop is None: - return - try: - await asyncio.to_thread(_refresh_volume_ranks, force=False) - except Exception: - pass - while not stop.is_set(): - try: - wait_sec = seconds_until_next_reset() - await asyncio.wait_for(stop.wait(), timeout=wait_sec) - break - except asyncio.TimeoutError: - pass - if stop.is_set(): - break - try: - await asyncio.to_thread(_refresh_volume_ranks, force=True) - except Exception: - pass - - -async def _archive_sync_loop() -> None: - global _archive_sync_stop - stop = _archive_sync_stop - if stop is None: - return - init_archive_db() - while not stop.is_set(): - try: - await _run_archive_sync_once() - except Exception: - pass - try: - await asyncio.wait_for(stop.wait(), timeout=float(ARCHIVE_SYNC_INTERVAL_SEC)) - except asyncio.TimeoutError: - pass - - -async def _run_supervisor_tick() -> dict: - dash = dashboard_store.snapshot_dict() - board = board_store.snapshot_dict() - settings = load_settings() - ai_fn = make_supervisor_ai_reply_fn(_all_exchanges_for_ai()) - return await asyncio.to_thread( - process_supervisor_tick, - dash if dash.get("ok") is not False else None, - board if board.get("ok") is not False else None, - settings, - reset_hour=trading_day_reset_hour(), - ai_reply_fn=ai_fn, - ) - - -@asynccontextmanager -async def _hub_lifespan(_app: FastAPI): - global _archive_sync_stop, _archive_sync_task, _volume_rank_stop, _volume_rank_task - set_supervisor_notify_hook(supervisor_store.bump) - await board_store.start(_run_board_aggregate) - await dashboard_store.start(_run_dashboard_aggregate) - await supervisor_store.start(_run_supervisor_tick) - await chart_poll_store.start(_run_chart_poll) - _archive_sync_stop = asyncio.Event() - _archive_sync_task = asyncio.create_task(_archive_sync_loop(), name="hub-archive-sync") - _volume_rank_stop = asyncio.Event() - _volume_rank_task = asyncio.create_task(_volume_rank_loop(), name="hub-volume-rank") - try: - yield - finally: - if _archive_sync_stop: - _archive_sync_stop.set() - if _archive_sync_task: - _archive_sync_task.cancel() - try: - await _archive_sync_task - except asyncio.CancelledError: - pass - _archive_sync_task = None - _archive_sync_stop = None - if _volume_rank_stop: - _volume_rank_stop.set() - if _volume_rank_task: - _volume_rank_task.cancel() - try: - await _volume_rank_task - except asyncio.CancelledError: - pass - _volume_rank_task = None - _volume_rank_stop = None - await chart_poll_store.stop() - await supervisor_store.stop() - await dashboard_store.stop() - await board_store.stop() - set_supervisor_notify_hook(None) - - -app = FastAPI(title="复盘系统中控", docs_url=None, redoc_url=None, lifespan=_hub_lifespan) -STATIC_DIR = DIR / "static" -_REPO_STATIC = _REPO_ROOT / "static" -_AI_REVIEW_RENDER_JS = _REPO_STATIC / "ai_review_render.js" -_TRADE_STATS_CALENDAR_CSS = _REPO_STATIC / "trade_stats_calendar.css" -_TRADE_STATS_CALENDAR_JS = _REPO_STATIC / "trade_stats_calendar.js" -_ACCOUNT_RISK_BADGE_CSS = _REPO_STATIC / "account_risk_badge.css" -_ACCOUNT_RISK_BADGE_JS = _REPO_STATIC / "account_risk_badge.js" - - -@app.get("/assets/account_risk_badge.css") -def hub_account_risk_badge_css(): - """与四所实例共用仓库根 static/account_risk_badge.css。""" - if not _ACCOUNT_RISK_BADGE_CSS.is_file(): - raise HTTPException(status_code=404, detail="account_risk_badge.css not found") - return FileResponse( - str(_ACCOUNT_RISK_BADGE_CSS), - media_type="text/css; charset=utf-8", - ) - - -@app.get("/assets/account_risk_badge.js") -def hub_account_risk_badge_js(): - """与四所实例共用仓库根 static/account_risk_badge.js。""" - if not _ACCOUNT_RISK_BADGE_JS.is_file(): - raise HTTPException(status_code=404, detail="account_risk_badge.js not found") - return FileResponse( - str(_ACCOUNT_RISK_BADGE_JS), - media_type="application/javascript; charset=utf-8", - ) - - -@app.get("/assets/ai_review_render.js") -def hub_ai_review_render_js(): - """与四所实例共用仓库根 static/ai_review_render.js(须在 /assets mount 之前注册)。""" - if not _AI_REVIEW_RENDER_JS.is_file(): - raise HTTPException(status_code=404, detail="ai_review_render.js not found") - return FileResponse( - str(_AI_REVIEW_RENDER_JS), - media_type="application/javascript; charset=utf-8", - ) - - -@app.get("/assets/trade_stats_calendar.css") -def hub_trade_stats_calendar_css(): - if not _TRADE_STATS_CALENDAR_CSS.is_file(): - raise HTTPException(status_code=404, detail="trade_stats_calendar.css not found") - return FileResponse( - str(_TRADE_STATS_CALENDAR_CSS), - media_type="text/css; charset=utf-8", - ) - - -@app.get("/assets/trade_stats_calendar.js") -def hub_trade_stats_calendar_js(): - if not _TRADE_STATS_CALENDAR_JS.is_file(): - raise HTTPException(status_code=404, detail="trade_stats_calendar.js not found") - return FileResponse( - str(_TRADE_STATS_CALENDAR_JS), - media_type="application/javascript; charset=utf-8", - ) - - -if STATIC_DIR.is_dir(): - app.mount("/assets", StaticFiles(directory=str(STATIC_DIR)), name="assets") - - -@app.middleware("http") -async def local_only(request: Request, call_next): - if HUB_ALLOW_PUBLIC: - return await call_next(request) - peer = request.client.host if request.client else None - if not _client_allowed(peer): - return JSONResponse({"detail": "forbidden"}, status_code=403) - return await call_next(request) - - -@app.middleware("http") -async def embed_frame_headers(request: Request, call_next): - response = await call_next(request) - if embed_allowed(): - ancestors = embed_frame_ancestors() - if ancestors == "*": - response.headers["Content-Security-Policy"] = "frame-ancestors *" - else: - response.headers["Content-Security-Policy"] = f"frame-ancestors 'self' {ancestors}" - return response - - -@app.middleware("http") -async def hub_password_gate(request: Request, call_next): - if not password_required(): - return await call_next(request) - path = request.url.path - if is_public_path(path, request.method): - return await call_next(request) - token = request.cookies.get(SESSION_COOKIE) - if validate_session_token(token): - return await call_next(request) - if path.startswith("/api/"): - return JSONResponse({"detail": "未登录", "login_required": True}, status_code=401) - from fastapi.responses import RedirectResponse - - nxt = path if path.startswith("/") else "/monitor" - return RedirectResponse(f"/login?next={nxt}", status_code=302) - - -def _shell_page(): - index = STATIC_DIR / "index.html" - if not index.is_file(): - return JSONResponse({"detail": "missing static/index.html"}, status_code=500) - return FileResponse(index) - - -def _login_page(): - login = STATIC_DIR / "login.html" - if not login.is_file(): - return JSONResponse({"detail": "missing static/login.html"}, status_code=500) - return FileResponse(login) - - -class LoginBody(BaseModel): - username: str = "" - password: str = "" - - -@app.get("/api/auth/status") -def api_auth_status(request: Request): - required = password_required() - logged_in = not required or validate_session_token(request.cookies.get(SESSION_COOKIE)) - return { - "required": required, - "logged_in": logged_in, - } - - -@app.post("/api/auth/login") -def api_auth_login(body: LoginBody, request: Request): - if not password_required(): - return {"ok": True, "auth_disabled": True} - if not verify_credentials(body.username, body.password): - raise HTTPException(status_code=401, detail="用户名或密码错误") - token = create_session_token(body.username) - embed = (request.headers.get("x-hub-embed") or "").strip() == "1" - resp = JSONResponse({"ok": True, "session_token": token, "embed": embed}) - set_session_cookie(resp, request, token, embed=embed) - return resp - - -@app.get("/embed-auth") -def embed_auth_login(request: Request, token: str = "", next: str = "/monitor"): - """ - 嵌入式打开:父页跨域 fetch 登录时 Cookie 可能写不进 iframe, - 用 session_token 在本页做一次导航,在 iframe 内写入 hub_sess。 - """ - from fastapi.responses import RedirectResponse - - dest = safe_next_path(next) - if not password_required(): - return RedirectResponse(dest, status_code=302) - if not validate_session_token(token): - q = urlencode({"next": dest, "embed": "1"}) - return RedirectResponse(f"/login?{q}", status_code=302) - resp = RedirectResponse(dest, status_code=302) - set_session_cookie(resp, request, token, embed=True) - return resp - - -@app.post("/api/auth/logout") -def api_auth_logout(request: Request): - embed = (request.headers.get("x-hub-embed") or "").strip() == "1" - resp = JSONResponse({"ok": True}) - clear_session_cookie(resp, request, embed=embed) - return resp - - -@app.get("/login") -def login_page(): - return _login_page() - - -@app.get("/") -def root_redirect(): - from fastapi.responses import RedirectResponse - - return RedirectResponse("/monitor") - - -@app.get("/monitor") -@app.get("/plan") -@app.get("/calculator") -@app.get("/market") -@app.get("/archive") -@app.get("/dashboard") -@app.get("/funds") -@app.get("/ai") -@app.get("/settings") -def shell_pages(): - return _shell_page() - - -def _all_exchanges_for_ai() -> list: - """AI 聚合用:含未启用账户(标记未监控)。""" - return list(load_settings().get("exchanges") or []) - - -from hub_ai.routes import create_hub_ai_router -from hub_dashboard import build_dashboard_payload, default_trading_day - -app.include_router(create_hub_ai_router(load_all_exchanges=_all_exchanges_for_ai)) - - -async def _run_dashboard_aggregate() -> dict: - try: - return await asyncio.to_thread( - build_dashboard_payload, - enabled_exchanges(), - trading_day=default_trading_day(), - ) - except Exception as exc: - return {"ok": False, "msg": str(exc), "error": "aggregate_failed"} - - -def _schedule_dashboard_refresh() -> None: - dashboard_store.request_refresh() - supervisor_store.request_refresh() - - -@app.get("/api/dashboard/daily") -def api_dashboard_daily(trading_day: str = ""): - day = (trading_day or "").strip()[:10] or default_trading_day() - if not (trading_day or "").strip(): - return dashboard_store.snapshot_dict() - try: - payload = build_dashboard_payload( - enabled_exchanges(), - trading_day=day, - ) - except Exception as exc: - raise HTTPException(status_code=502, detail=str(exc)) from exc - return {**payload, "dashboard_version": dashboard_store.version} - - -@app.get("/api/dashboard/stream") -async def api_dashboard_stream(): - from fastapi.responses import StreamingResponse - - return StreamingResponse( - dashboard_store.iter_sse(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - -@app.post("/api/dashboard/refresh") -async def api_dashboard_refresh(): - _schedule_dashboard_refresh() - return {"ok": True, "dashboard_version": dashboard_store.version} - - -@app.get("/api/ai/supervisor/stream") -async def api_supervisor_stream(): - from fastapi.responses import StreamingResponse - - return StreamingResponse( - supervisor_store.iter_sse(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - -@app.post("/api/ai/supervisor/refresh") -async def api_supervisor_refresh(): - supervisor_store.request_refresh() - return {"ok": True, "supervisor_version": supervisor_store.version} - - -@app.get("/trade") -def trade_removed_redirect(): - from fastapi.responses import RedirectResponse - - return RedirectResponse("/monitor", status_code=302) - - -@app.get("/api/settings") -def api_get_settings(): - return load_settings() - - -class SettingsDisplayBody(BaseModel): - show_account_pnl: bool = True - show_nav_funds: bool = True - show_nav_dashboard: bool = True - show_nav_plan: bool = True - show_nav_archive: bool = True - show_nav_ai: bool = True - show_nav_calculator: bool = True - - -class SupervisorSettingsBody(BaseModel): - enabled: bool = True - wechat_webhook: str = "" - wechat_link_base: str = "" - wechat_prefix: str = "【交易监管】" - wechat_on_program_tp_sl: bool = True - manual_close_daily_warn: int = 2 - interval_warn_minutes: int = 15 - freq_30m_count: int = 2 - reopen_after_close_minutes: int = 30 - - -class SettingsBody(BaseModel): - exchanges: list[dict] = Field(default_factory=list) - display: SettingsDisplayBody | None = None - supervisor: SupervisorSettingsBody | None = None - - -@app.post("/api/settings") -def api_save_settings(body: SettingsBody): - force_off = env_force_disabled_ids() - to_save = [] - for ex in body.exchanges: - row = dict(ex) - eid = str(row.get("id", "")).strip() - if eid in force_off: - row["enabled"] = False - row.pop("env_disabled", None) - to_save.append(row) - existing = load_settings() - display = normalize_display_prefs(existing.get("display")) - if body.display is not None: - display = normalize_display_prefs(body.display.model_dump()) - supervisor = normalize_supervisor_settings(existing.get("supervisor")) - if body.supervisor is not None: - supervisor = normalize_supervisor_settings(body.supervisor.model_dump()) - save_settings({"version": 1, "exchanges": to_save, "display": display, "supervisor": supervisor}) - return {"ok": True, "settings": load_settings()} - - -class TrendCalculatorBody(BaseModel): - direction: str = "long" - capital_usdt: float = Field(gt=0) - risk_percent: float = Field(gt=0, le=100) - leverage: int = Field(ge=1, le=125) - entry_price: float = Field(gt=0) - stop_loss: float = Field(gt=0) - add_upper: float = Field(gt=0) - take_profit: float = Field(gt=0) - dca_legs: int = Field(default=5, ge=1, le=20) - exchange_id: str = "0" - base: str = "ETH" - - -class RollAddLegBody(BaseModel): - add_price: float = Field(gt=0) - new_stop_loss: float = Field(gt=0) - - -class RollCalculatorBody(BaseModel): - direction: str = "long" - capital_usdt: float = Field(gt=0) - risk_percent: float = Field(gt=0, le=100) - entry_price: float = Field(gt=0) - stop_loss: float = Field(gt=0) - take_profit: float = Field(gt=0) - add_legs: list[RollAddLegBody] = Field(default_factory=list, max_length=3) - legs_done: int = Field(default=0, ge=0, le=3) - exchange_id: str = "0" - base: str = "ETH" - - -@app.get("/api/calculator/exchanges") -def api_calculator_exchanges(): - from hub_calculator_market_lib import list_calculator_exchanges - - return {"ok": True, "data": list_calculator_exchanges()} - - -@app.get("/api/calculator/market") -def api_calculator_market(exchange_id: str = "0", base: str = "ETH"): - from hub_calculator_market_lib import get_calculator_market - - data, err = get_calculator_market(exchange_id, base) - if err: - return JSONResponse({"ok": False, "msg": err}, status_code=400) - return {"ok": True, "data": data} - - -@app.post("/api/calculator/trend") -def api_calculator_trend(body: TrendCalculatorBody): - from hub_calculator_lib import calc_trend_calculator - - data, err = calc_trend_calculator( - direction=body.direction, - capital_usdt=body.capital_usdt, - risk_percent=body.risk_percent, - leverage=body.leverage, - entry_price=body.entry_price, - stop_loss=body.stop_loss, - add_upper=body.add_upper, - take_profit=body.take_profit, - dca_legs=body.dca_legs, - exchange_id=body.exchange_id, - base=body.base, - ) - if err: - return JSONResponse({"ok": False, "msg": err}, status_code=400) - return {"ok": True, "data": data} - - -@app.post("/api/calculator/roll") -def api_calculator_roll(body: RollCalculatorBody): - from hub_calculator_lib import calc_roll_calculator - - data, err = calc_roll_calculator( - direction=body.direction, - capital_usdt=body.capital_usdt, - risk_percent=body.risk_percent, - entry_price=body.entry_price, - stop_loss=body.stop_loss, - take_profit=body.take_profit, - add_legs=[leg.model_dump() for leg in body.add_legs], - legs_done=body.legs_done, - exchange_id=body.exchange_id, - base=body.base, - ) - if err: - return JSONResponse({"ok": False, "msg": err}, status_code=400) - return {"ok": True, "data": data} - - -def _find_exchange_by_key(exchange_key: str) -> dict | None: - key = (exchange_key or "").strip().lower() - if not key: - return None - for ex in load_settings().get("exchanges") or []: - if str(ex.get("key") or "").strip().lower() == key: - return ex - if str(ex.get("id") or "").strip() == exchange_key.strip(): - return ex - return None - - -def _fetch_instance_trades_archive_sync( - ex: dict, - *, - days: int = 365, - limit: int = 2000, -) -> dict: - base = (ex.get("flask_url") or "").rstrip("/") - if not base: - return {"ok": False, "msg": "未配置 flask_url"} - params = {"days": str(int(days)), "limit": str(int(limit))} - url = f"{base}/api/hub/trades/archive?{urlencode(params)}" - try: - with httpx.Client(timeout=HUB_FLASK_TIMEOUT) as client: - r = client.get(url, headers=_hub_headers()) - if r.status_code >= 400: - parsed = _parse_http_json_body(r) - parsed.setdefault("ok", False) - parsed.setdefault("status", r.status_code) - return parsed - data = r.json() if r.content else {} - if isinstance(data, dict): - data.setdefault("ok", True) - return data - return {"ok": False, "msg": "无效 JSON"} - except Exception as e: - return {"ok": False, "msg": str(e)} - - -def _fetch_instance_ohlcv_sync( - ex: dict, - *, - symbol: str, - timeframe: str, - since_ms: int | None, - limit: int, -) -> dict: - base = (ex.get("flask_url") or "").rstrip("/") - if not base: - return {"ok": False, "msg": "未配置 flask_url"} - params = {"symbol": symbol, "timeframe": timeframe, "limit": str(int(limit))} - if since_ms is not None and int(since_ms) > 0: - params["since_ms"] = str(int(since_ms)) - url = f"{base}/api/hub/ohlcv?{urlencode(params)}" - try: - with httpx.Client(timeout=HUB_FLASK_TIMEOUT) as client: - r = client.get(url, headers=_hub_headers()) - if r.status_code >= 400: - parsed = _parse_http_json_body(r) - parsed.setdefault("ok", False) - return parsed - data = r.json() if r.content else {} - return data if isinstance(data, dict) else {"ok": False, "msg": "无效 JSON"} - except Exception as e: - return {"ok": False, "msg": str(e)} - - -@app.get("/api/chart/meta") -def api_chart_meta(): - tfs = [tf for tf in CHART_TIMEFRAME_ORDER if tf in CHART_TIMEFRAMES] - exchanges = [] - for ex in enabled_exchanges(load_settings()): - exchanges.append( - { - "id": ex.get("id"), - "key": ex.get("key"), - "name": ex.get("name"), - } - ) - return { - "ok": True, - "timeframes": [tf for tf in tfs if tf in CHART_TIMEFRAMES], - "retention_days": retention_days(), - "retention_policy": retention_policy_meta(), - "limits": {tf: bar_limit_for_timeframe(tf) for tf in tfs if tf in CHART_TIMEFRAMES}, - "initial_limits": {tf: chart_initial_limit(tf) for tf in tfs if tf in CHART_TIMEFRAMES}, - "chunk_limits": {tf: chart_chunk_limit(tf) for tf in tfs if tf in CHART_TIMEFRAMES}, - "memory_caps": {tf: chart_memory_cap(tf) for tf in tfs if tf in CHART_TIMEFRAMES}, - "exchanges": exchanges, - "volume_rank_top_n": TOP_N_DEFAULT, - "volume_rank_reset_hour": volume_rank_reset_hour(), - } - - -@app.get("/api/chart/volume-rank") -def api_chart_volume_rank(exchange_key: str = "", refresh: str = ""): - force = (refresh or "").strip().lower() in ("1", "true", "yes", "on") - if force: - result = _refresh_volume_ranks(force=True) - if not result.get("ok"): - raise HTTPException(status_code=502, detail=result.get("msg") or "刷新失败") - cache = _get_volume_rank_cache() - ex_k = (exchange_key or "").strip().lower() - targets = enabled_exchanges(load_settings()) - required_keys = [ - str(ex.get("key") or "").strip().lower() - for ex in targets - if ex.get("enabled") and str(ex.get("key") or "").strip() - ] - need_keys = [ex_k] if ex_k else required_keys - if cache_needs_refresh(cache, required_keys=need_keys): - _refresh_volume_ranks(force=True) - cache = _get_volume_rank_cache() - elif ex_k: - row = (cache.get("exchanges") or {}).get(ex_k) or {} - if _exchange_rank_row_stale(row): - _refresh_volume_ranks(force=True) - cache = _get_volume_rank_cache() - if ex_k: - ex = _find_exchange_by_key(ex_k) - if not ex: - raise HTTPException(status_code=400, detail="交易所不存在") - payload = get_cached_rank(cache, ex_k, top_n=TOP_N_DEFAULT) - payload["items"] = [ - {**row, "volume_label": format_volume_quote(row.get("volume_quote"))} - for row in payload.get("items") or [] - ] - payload["reset_hour"] = volume_rank_reset_hour() - err = ((cache.get("exchanges") or {}).get(ex_k) or {}).get("error") - if err and not payload.get("items"): - payload["ok"] = False - payload["msg"] = err - return payload - exchanges_out = {} - for ex in enabled_exchanges(load_settings()): - key = str(ex.get("key") or "").strip().lower() - if not key: - continue - row = get_cached_rank(cache, key, top_n=TOP_N_DEFAULT) - row["name"] = ex.get("name") - row["items"] = [ - {**item, "volume_label": format_volume_quote(item.get("volume_quote"))} - for item in row.get("items") or [] - ] - exchanges_out[key] = row - return { - "ok": True, - "rank_date": cache.get("rank_date"), - "updated_at": cache.get("updated_at"), - "reset_hour": volume_rank_reset_hour(), - "exchanges": exchanges_out, - } - - -@app.post("/api/chart/volume-rank/refresh") -async def api_chart_volume_rank_refresh(): - result = await asyncio.to_thread(_refresh_volume_ranks, force=True) - if not result.get("ok"): - raise HTTPException(status_code=502, detail=result.get("msg") or "刷新失败") - return result - - -@app.get("/api/chart/ohlcv") -def api_chart_ohlcv( - exchange_key: str = "", - symbol: str = "", - timeframe: str = "1d", - refresh: str = "", - tail: str = "", - limit: int = 0, - before_ms: str = "", -): - ex = _find_exchange_by_key(exchange_key) - if not ex: - raise HTTPException(status_code=400, detail="交易所不存在") - if not ex.get("enabled"): - raise HTTPException(status_code=400, detail="该交易所未启用") - sym = (symbol or "").strip().upper() - if not sym: - raise HTTPException(status_code=400, detail="请输入币种") - ex_key = str(ex.get("key") or "").strip().lower() - force = (refresh or "").strip().lower() in ("1", "true", "yes", "on") - tail_refresh = (tail or "").strip().lower() in ("1", "true", "yes", "on") - lim = int(limit) if int(limit or 0) > 0 else None - bms_raw = (before_ms or "").strip() - bms = None - if bms_raw: - try: - bms = int(bms_raw) - except ValueError: - raise HTTPException(status_code=400, detail="before_ms 无效") - clear_db = force and not tail_refresh and bms is None - - def remote_fetch(**kwargs): - tf_use = kwargs.get("timeframe") or timeframe - return _fetch_instance_ohlcv_sync( - ex, - symbol=kwargs.get("symbol") or sym, - timeframe=tf_use, - since_ms=kwargs.get("since_ms"), - limit=int(kwargs.get("limit") or bar_limit_for_timeframe(tf_use)), - ) - - result = resolve_chart_bars( - ex_key, - sym, - timeframe, - remote_fetch, - force_refresh=force, - tail_refresh=tail_refresh, - clear_db=clear_db, - limit=lim, - before_ms=bms, - ) - if not result.get("ok"): - raise HTTPException(status_code=502, detail=result.get("msg") or "K线加载失败") - if not result.get("candles") and result.get("before_ms") is None: - raise HTTPException(status_code=502, detail=result.get("msg") or "无 K 线") - tick = result.get("price_tick") - last = result["candles"][-1] if result.get("candles") else None - result["ohlcv"] = format_ohlcv_detail( - { - "open": last.get("open") if last else None, - "high": last.get("high") if last else None, - "low": last.get("low") if last else None, - "close": last.get("close") if last else None, - "volume": last.get("volume") if last else None, - } - if last - else None, - tick, - ) - result["chart_version"] = chart_poll_store.version - result["series_version"] = chart_poll_store.series_version(ex_key, sym, timeframe) - result["chart_poll_interval_sec"] = HUB_CHART_POLL_INTERVAL - return result - - -class ChartWatchBody(BaseModel): - exchange_key: str = "" - symbol: str = "" - timeframe: str = "5m" - - -@app.post("/api/chart/watch") -async def api_chart_watch(body: ChartWatchBody = Body(...)): - ex_k = (body.exchange_key or "").strip().lower() - sym = (body.symbol or "").strip().upper() - tf = (body.timeframe or "5m").strip() - if not ex_k or not sym: - raise HTTPException(status_code=400, detail="缺少 exchange_key 或 symbol") - if tf not in CHART_TIMEFRAMES: - raise HTTPException(status_code=400, detail="不支持的周期") - key = chart_poll_store.touch_watch(ex_k, sym, tf) - chart_poll_store.request_refresh() - return { - "ok": True, - "series_key": key, - "series_version": chart_poll_store.series_version(ex_k, sym, tf), - "chart_version": chart_poll_store.version, - "watch_ttl_sec": HUB_CHART_WATCH_TTL_SEC, - } - - -@app.post("/api/chart/unwatch") -async def api_chart_unwatch(body: ChartWatchBody = Body(...)): - chart_poll_store.clear_watch(body.exchange_key, body.symbol, body.timeframe) - return {"ok": True} - - -@app.get("/api/chart/stream") -async def api_chart_stream(): - from fastapi.responses import StreamingResponse - - return StreamingResponse( - chart_poll_store.iter_sse(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - -@app.get("/api/chart/poll/meta") -async def api_chart_poll_meta(): - return chart_poll_store.event_dict() - - -@app.get("/api/settings/meta") -def api_settings_meta(): - po = public_origin() - return { - "env_disabled_ids": sorted(env_force_disabled_ids()), - "hub_bridge_token_set": bool(HUB_BRIDGE_TOKEN), - "capability_options": ["key", "trend"], - "public_origin": f"{po[0]}://{po[1]}" if po else None, - "public_origin_hint": ( - "未设置 HUB_PUBLIC_ORIGIN 时,复盘链接若为 127.0.0.1,仅服务器本机浏览器可打开" - if not po - else "复盘/展示链接已替换为对外地址" - ), - "password_required": password_required(), - } - - -async def _fetch_agent_status(client: httpx.AsyncClient, ex: dict) -> dict: - url = f"{ex['agent_url'].rstrip('/')}/status" - try: - r = await client.get(url, headers=_agent_headers(), timeout=HUB_AGENT_TIMEOUT) - body = r.json() if r.content else {} - return { - "id": ex["id"], - "name": ex["name"], - "key": ex.get("key"), - "agent_url": ex["agent_url"], - "flask_url": ex.get("flask_url"), - "capabilities": ex.get("capabilities") or [], - "http_ok": r.status_code == 200, - "agent": body, - "error": body.get("error") if isinstance(body, dict) else None, - } - except Exception as e: - return { - "id": ex["id"], - "name": ex["name"], - "key": ex.get("key"), - "agent_url": ex["agent_url"], - "flask_url": ex.get("flask_url"), - "capabilities": ex.get("capabilities") or [], - "http_ok": False, - "error": str(e), - "agent": None, - } - - -def _parse_http_json_body(r: httpx.Response) -> dict: - text = (r.text or "").strip() - if not text: - return {"ok": False, "status": r.status_code, "text": "(empty body)"} - try: - data = r.json() - if isinstance(data, dict): - return data - return {"ok": False, "status": r.status_code, "text": text[:500]} - except Exception: - snippet = text[:500] - if snippet.lstrip().lower().startswith(" dict | None: - base = (ex.get("flask_url") or "").rstrip("/") - if not base: - return None - try: - if method == "GET": - r = await client.get(f"{base}{path}", headers=_hub_headers(), timeout=HUB_FLASK_TIMEOUT) - else: - headers = {**_hub_headers(), "Content-Type": "application/json"} - if json_body is not None: - r = await client.post( - f"{base}{path}", headers=headers, json=json_body, timeout=120.0 - ) - else: - r = await client.post( - f"{base}{path}", headers=headers, data=data, timeout=120.0 - ) - if r.status_code >= 400: - parsed = _parse_http_json_body(r) - parsed.setdefault("ok", False) - parsed.setdefault("status", r.status_code) - return parsed - return _parse_http_json_body(r) - except Exception as e: - return {"ok": False, "error": str(e)} - - -async def _notify_instance_user_close( - client: httpx.AsyncClient, ex: dict, *, count: int = 1 -) -> dict | None: - """登记实例侧用户主动平仓风控(中控点平仓/全平)。""" - if count <= 0 or not (ex.get("flask_url") or "").strip(): - return None - return await _fetch_flask_json( - client, - ex, - "/api/hub/account-risk/user-close", - method="POST", - json_body={"source": "user_hub", "count": int(count)}, - ) - - -def _flask_error_from_hub_mon(hub_mon: dict | None) -> str | None: - if not isinstance(hub_mon, dict) or hub_mon.get("ok") is not False: - return None - st = hub_mon.get("status") - if st == 404: - return ( - "HTTP 404:该 Flask 未注册 /api/hub/*(hub_bridge 未加载)。" - "请在仓库根目录 git pull 后 pm2 restart crypto_binance crypto_gate crypto_gate_bot," - "并查看启动日志是否含 [hub_bridge] ImportError" - ) - return ( - hub_mon.get("msg") - or hub_mon.get("error") - or (f"HTTP {st}" if st else None) - or (str(hub_mon.get("text") or "")[:120] or None) - ) - - -def _cond_order_trigger_key(price: object) -> str | None: - if price is None or price == "": - return None - try: - return f"{float(price):.12g}" - except (TypeError, ValueError): - return None - - -def _merge_conditional_orders_no_dup( - existing: list, extra: list -) -> list: - """子代理已拉到的条件单与 Flask exchange_tpsl 合成行按触发价/订单号去重,避免 Gate 显示 4 笔实为 2 笔。""" - if not extra: - return list(existing) if existing else [] - if not existing: - return list(extra) - triggers: set[str] = set() - order_ids: set[str] = set() - out: list = [] - for row in existing: - if not isinstance(row, dict): - continue - out.append(row) - k = _cond_order_trigger_key(row.get("trigger_price")) - if k: - triggers.add(k) - oid = row.get("id") - if oid not in (None, ""): - order_ids.add(str(oid)) - for row in extra: - if not isinstance(row, dict): - continue - k = _cond_order_trigger_key(row.get("trigger_price")) - oid = row.get("id") - if k and k in triggers: - continue - if oid not in (None, "") and str(oid) in order_ids: - continue - out.append(row) - if k: - triggers.add(k) - if oid not in (None, ""): - order_ids.add(str(oid)) - return out - - -def _tpsl_slots_to_conditional_orders(exchange_tpsl: dict, symbol: str) -> list[dict]: - """将实例 price_snapshot 的 exchange_tpsl 转为中控条件单结构。""" - out: list[dict] = [] - if not isinstance(exchange_tpsl, dict): - return out - for role, label in (("sl", "止损"), ("tp", "止盈")): - slot = exchange_tpsl.get(role) - if not isinstance(slot, dict): - continue - trig = slot.get("trigger_price") - if trig is None: - continue - try: - trig_f = float(trig) - except (TypeError, ValueError): - continue - oid = slot.get("order_id") - out.append( - { - "id": str(oid) if oid is not None else "", - "symbol": symbol, - "channel": "algo", - "category": "conditional", - "label": f"{label} {trig_f:g}", - "trigger_price": trig_f, - "amount": slot.get("amount"), - "status": "open", - } - ) - return out - - -def _exchange_tpsl_from_hub_order(hub_orders: list, symbol: str, side: str) -> dict | None: - """趋势保本移交后:用下单监控计划价补全 exchange_tpsl(与实例页一致)。""" - side_l = (side or "").lower() - for o in hub_orders: - if not isinstance(o, dict): - continue - o_sym = o.get("exchange_symbol") or o.get("symbol") or "" - if not _symbols_match(symbol, o_sym): - continue - if (o.get("direction") or "").lower() != side_l: - continue - sl = o.get("stop_loss") - tp = o.get("take_profit") - if sl in (None, "") and tp in (None, ""): - continue - slots: dict = {"sl": None, "tp": None} - if sl not in (None, ""): - try: - slots["sl"] = {"trigger_price": float(sl), "order_id": None} - except (TypeError, ValueError): - pass - if tp not in (None, ""): - try: - slots["tp"] = {"trigger_price": float(tp), "order_id": None} - except (TypeError, ValueError): - pass - if slots["sl"] or slots["tp"]: - return slots - return None - - -def _find_exchange_tpsl_for_position( - symbol: str, - side: str, - order_prices: list, - hub_orders: list, -) -> dict | None: - side_l = (side or "").lower() - op_by_id = { - op.get("id"): op - for op in order_prices - if isinstance(op, dict) and op.get("id") is not None - } - for o in hub_orders: - if not isinstance(o, dict): - continue - o_sym = o.get("exchange_symbol") or o.get("symbol") or "" - if not _symbols_match(symbol, o_sym): - continue - if (o.get("direction") or "").lower() != side_l: - continue - op = op_by_id.get(o.get("id")) - if not isinstance(op, dict): - continue - et = op.get("exchange_tpsl") - if isinstance(et, dict) and (et.get("sl") or et.get("tp")): - return et - for op in order_prices: - if not isinstance(op, dict): - continue - if not _symbols_match(symbol, op.get("symbol") or ""): - continue - et = op.get("exchange_tpsl") - if isinstance(et, dict) and (et.get("sl") or et.get("tp")): - return et - return None - - -def _merge_flask_order_price_fields(hub_mon: dict | None, snap: dict | None) -> None: - """将 price_snapshot 中的快照盈亏比、已保本状态合并进 hub_monitor.orders。""" - if not isinstance(hub_mon, dict) or not isinstance(snap, dict): - return - order_prices = snap.get("order_prices") or [] - op_by_id = { - op.get("id"): op - for op in order_prices - if isinstance(op, dict) and op.get("id") is not None - } - orders = hub_mon.get("orders") or [] - if not isinstance(orders, list): - return - for o in orders: - if not isinstance(o, dict): - continue - op = op_by_id.get(o.get("id")) - if not isinstance(op, dict): - continue - if op.get("rr_ratio") is not None: - o["rr_ratio"] = op["rr_ratio"] - if "sl_breakeven_secured" in op: - o["sl_breakeven_secured"] = bool(op["sl_breakeven_secured"]) - for key in ( - "stop_loss", - "take_profit", - "stop_loss_display", - "take_profit_display", - "display_rr_ratio", - "exchange_initial_margin", - "plan_margin", - "time_close_enabled", - "time_close_hours", - "time_close_at_ms", - "time_close_label", - "time_close_countdown", - "time_close_remaining_sec", - ): - if key in op and op[key] not in (None, ""): - o[key] = op[key] - - -def _merge_flask_position_breakeven(agent_row: dict, snap: dict | None, hub_mon: dict | None) -> None: - """将 price_snapshot 的已保本状态同步到 agent 持仓,供中控首页表格展示。""" - ag = agent_row.get("agent") - if not isinstance(ag, dict) or not isinstance(snap, dict): - return - positions = ag.get("positions") - if not isinstance(positions, list) or not positions: - return - order_prices = snap.get("order_prices") or [] - hub_orders = [] - if isinstance(hub_mon, dict): - hub_orders = hub_mon.get("orders") or [] - op_by_id = { - op.get("id"): op - for op in order_prices - if isinstance(op, dict) and op.get("id") is not None - } - for p in positions: - if not isinstance(p, dict): - continue - sym = p.get("symbol") or "" - side = (p.get("side") or "").lower() - matched = None - for o in hub_orders: - if not isinstance(o, dict): - continue - o_sym = o.get("exchange_symbol") or o.get("symbol") or "" - if not _symbols_match(sym, o_sym): - continue - if (o.get("direction") or "").lower() != side: - continue - matched = op_by_id.get(o.get("id")) - break - if matched is None: - for op in order_prices: - if not isinstance(op, dict): - continue - if not _symbols_match(sym, op.get("symbol") or ""): - continue - matched = op - break - if isinstance(matched, dict) and "sl_breakeven_secured" in matched: - p["sl_breakeven_secured"] = bool(matched["sl_breakeven_secured"]) - - -def _agent_position_has_mark(p: dict) -> bool: - try: - v = float(p.get("mark_price")) - return v > 0 - except (TypeError, ValueError): - return False - - -def _apply_agent_mark_price(p: dict, mark_price: object, mark_display: object = None) -> None: - try: - mpf = float(mark_price) - except (TypeError, ValueError): - return - if mpf <= 0: - return - p["mark_price"] = mpf - disp = mark_display - if disp is not None and str(disp).strip() not in ("", "-"): - p["mark_price_fmt"] = str(disp) - - -def _find_matched_order_price_op( - p: dict, - order_prices: list, - hub_orders: list, - op_by_id: dict, -) -> dict | None: - sym = p.get("symbol") or "" - side = (p.get("side") or "").lower() - for o in hub_orders: - if not isinstance(o, dict): - continue - o_sym = o.get("exchange_symbol") or o.get("symbol") or "" - if not _symbols_match(sym, o_sym): - continue - if (o.get("direction") or "").lower() != side: - continue - matched = op_by_id.get(o.get("id")) - if isinstance(matched, dict): - return matched - break - for op in order_prices: - if not isinstance(op, dict): - continue - if not _symbols_match(sym, op.get("symbol") or ""): - continue - return op - return None - - -def _merge_flask_position_mark_price( - agent_row: dict, snap: dict | None, hub_mon: dict | None -) -> None: - """子代理无标记价时,用实例 price_snapshot 的交易所标记价补全中控持仓展示。""" - ag = agent_row.get("agent") - if not isinstance(ag, dict) or not isinstance(snap, dict): - return - positions = ag.get("positions") - if not isinstance(positions, list) or not positions: - return - order_prices = snap.get("order_prices") or [] - hub_orders = [] - if isinstance(hub_mon, dict): - hub_orders = hub_mon.get("orders") or [] - op_by_id = { - op.get("id"): op - for op in order_prices - if isinstance(op, dict) and op.get("id") is not None - } - for p in positions: - if not isinstance(p, dict) or _agent_position_has_mark(p): - continue - matched = _find_matched_order_price_op(p, order_prices, hub_orders, op_by_id) - if isinstance(matched, dict): - _apply_agent_mark_price( - p, - matched.get("exchange_mark_price"), - matched.get("exchange_mark_price_display"), - ) - position_marks = snap.get("position_marks") or [] - if not isinstance(position_marks, list): - return - for p in positions: - if not isinstance(p, dict) or _agent_position_has_mark(p): - continue - sym = p.get("symbol") or "" - side = (p.get("side") or "").lower() - for pm in position_marks: - if not isinstance(pm, dict): - continue - if not _symbols_match(sym, pm.get("symbol") or ""): - continue - if (pm.get("side") or "").lower() != side: - continue - _apply_agent_mark_price( - p, pm.get("mark_price"), pm.get("mark_price_display") - ) - break - - -def _merge_flask_exchange_tpsl(agent_row: dict, snap: dict | None, hub_mon: dict | None) -> None: - """子代理挂单为空时,用实例 Flask 已算好的 exchange_tpsl 补全展示。""" - ag = agent_row.get("agent") - if not isinstance(ag, dict): - return - positions = ag.get("positions") - if not isinstance(positions, list) or not positions: - return - if not isinstance(snap, dict): - return - order_prices = snap.get("order_prices") or [] - hub_orders = [] - if isinstance(hub_mon, dict): - hub_orders = hub_mon.get("orders") or [] - for p in positions: - if not isinstance(p, dict): - continue - sym = p.get("symbol") or "" - side = p.get("side") or "" - et = _find_exchange_tpsl_for_position(sym, side, order_prices, hub_orders) - if not et: - et = _exchange_tpsl_from_hub_order(hub_orders, sym, side) - if not et: - continue - p["exchange_tpsl"] = et - cond = p.get("conditional_orders") or [] - merged = _tpsl_slots_to_conditional_orders(et, sym) - p["conditional_orders"] = _merge_conditional_orders_no_dup(cond, merged) - - -async def _fetch_exchange_flask_bundle( - client: httpx.AsyncClient, ex: dict -) -> tuple[dict | None, dict | None, list | None, dict | None, dict | None]: - """单所 Flask:monitor / meta / price_snapshot / account(有 flask_url 时)并行拉取。""" - caps = ex.get("capabilities") or [] - tasks = [ - _fetch_flask_json(client, ex, "/api/hub/monitor"), - _fetch_flask_json(client, ex, "/api/hub/meta"), - ] - has_flask = bool((ex.get("flask_url") or "").strip()) - if has_flask: - tasks.extend( - [ - _fetch_flask_json(client, ex, "/api/price_snapshot"), - _fetch_flask_json(client, ex, "/api/hub/account"), - ] - ) - results = await asyncio.gather(*tasks) - hub_mon = results[0] - meta = results[1] - snap = results[2] if has_flask and len(results) > 2 else None - account = results[3] if has_flask and len(results) > 3 else None - key_prices = None - want_prices = HUB_BOARD_KEY_PRICES and "key" in caps - if want_prices and isinstance(snap, dict): - key_prices = snap.get("key_prices") - return ( - hub_mon, - meta, - key_prices, - snap if isinstance(snap, dict) else None, - account if isinstance(account, dict) else None, - ) - - -async def _assemble_board_row( - client: httpx.AsyncClient, ex: dict, agent_row: dict -) -> dict: - hub_mon, meta, key_prices, snap, account = await _fetch_exchange_flask_bundle( - client, ex - ) - if isinstance(hub_mon, dict): - _merge_flask_order_price_fields(hub_mon, snap) - _merge_flask_exchange_tpsl(agent_row, snap, hub_mon if isinstance(hub_mon, dict) else None) - _merge_flask_position_breakeven(agent_row, snap, hub_mon if isinstance(hub_mon, dict) else None) - _merge_flask_position_mark_price(agent_row, snap, hub_mon if isinstance(hub_mon, dict) else None) - flask_ok = isinstance(hub_mon, dict) and hub_mon.get("ok") is not False - acct_ok = isinstance(account, dict) and account.get("ok") is not False - raw_review = (ex.get("review_url") or "").strip() - review_link = browser_url(raw_review) if raw_review else default_review_url( - ex.get("flask_url") - ) - return { - **agent_row, - "flask_url": ex.get("flask_url") or "", - "flask_url_browser": browser_url(ex.get("flask_url")), - "review_url": review_link, - "hub_monitor": hub_mon, - "flask_ok": flask_ok, - "flask_error": _flask_error_from_hub_mon(hub_mon if isinstance(hub_mon, dict) else None), - "meta": (meta or {}).get("meta") if isinstance(meta, dict) else meta, - "key_prices": key_prices, - "funding_usdt": account.get("funding_usdt") if acct_ok else None, - "trading_usdt": account.get("trading_usdt") if acct_ok else None, - "available_trading_usdt": account.get("available_trading_usdt") if acct_ok else None, - "account_ok": acct_ok, - } - - -async def _build_monitor_board_payload() -> dict: - exchanges = enabled_exchanges() - async with httpx.AsyncClient() as client: - agent_rows = await asyncio.gather( - *[_fetch_agent_status(client, ex) for ex in exchanges] - ) - out = await asyncio.gather( - *[ - _assemble_board_row(client, ex, agent_row) - for ex, agent_row in zip(exchanges, agent_rows) - ] - ) - return { - "rows": list(out), - "updated_at": __import__("datetime").datetime.now().isoformat(timespec="seconds"), - } - - -@app.get("/api/monitor/board") -@app.get("/api/monitor/board/snapshot") -async def api_monitor_board_snapshot(): - """读后台缓存快照;完整聚合由 hub 每 HUB_BOARD_POLL_INTERVAL 秒执行。""" - return board_store.snapshot_dict() - - -@app.get("/api/monitor/board/stream") -async def api_monitor_board_stream(): - from fastapi.responses import StreamingResponse - - return StreamingResponse( - board_store.iter_sse(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - -@app.post("/api/monitor/board/refresh") -async def api_monitor_board_refresh(): - _schedule_board_refresh() - return {"ok": True, "board_version": board_store.version} - - -@app.get("/api/host/status") -async def api_host_status(): - from hub_host_status_lib import get_host_status - - return await asyncio.to_thread(get_host_status) - - -def _require_hub_logged_in(request: Request) -> None: - if password_required() and not validate_session_token(request.cookies.get(SESSION_COOKIE)): - raise HTTPException(status_code=401, detail="未登录中控") - - -@app.get("/api/instance/open-url") -def api_instance_open_url( - request: Request, - exchange_id: str, - next: str = "/", - embed: str = "", - hub_theme: str = "", -): - """已登录中控时生成实例 SSO 打开链接(2h 有效、单次使用,复用 HUB_BRIDGE_TOKEN)。""" - _require_hub_logged_in(request) - if not HUB_BRIDGE_TOKEN: - raise HTTPException(status_code=503, detail="未配置 HUB_BRIDGE_TOKEN,无法签发实例打开链接") - ex = _find_exchange(exchange_id) - if not ex: - raise HTTPException(status_code=404, detail="未知交易所 id") - base = browser_url((ex.get("flask_url") or "").strip()).rstrip("/") - if not base: - raise HTTPException(status_code=400, detail="该账户未配置 flask_url") - ex_key = (ex.get("key") or "").strip().lower() - if not ex_key: - raise HTTPException(status_code=400, detail="该账户缺少 key(用于 SSO 校验)") - nxt = safe_next_path(next) - token = mint_hub_sso_token(ex_key, nxt) - if not token: - raise HTTPException(status_code=503, detail="签发 SSO 失败") - params = {"token": token, "next": nxt} - if (embed or "").strip().lower() in ("1", "true", "yes", "on"): - params["embed"] = "1" - ht = (hub_theme or "").strip().lower() - if ht in ("light", "dark"): - params["hub_theme"] = ht - q = urlencode(params) - return { - "ok": True, - "url": f"{base}/hub-sso?{q}", - "expires_in": HUB_SSO_TTL_SEC, - "exchange_id": exchange_id, - "exchange_key": ex_key, - } - - -class CloseAllBody(BaseModel): - exclude_ids: list[str] = Field(default_factory=list) - - -class ClosePositionBody(BaseModel): - symbol: str - side: str - - -class CancelOrderBody(BaseModel): - symbol: str - order_id: str - channel: str = "regular" - - -class CancelSymbolOrdersBody(BaseModel): - symbol: str - scope: str = "all" - - -class PlaceTpslBody(BaseModel): - symbol: str - side: str - stop_loss: float - take_profit: float - contracts: float | None = None - - -class TrendPlanActionBody(BaseModel): - plan_id: int - breakeven_offset_pct: float | None = None - - -def _flask_hub_messages(parsed: dict | None) -> tuple[bool, str]: - if not isinstance(parsed, dict): - return False, "实例返回无效" - msgs = list(parsed.get("messages") or []) - if parsed.get("msg"): - msgs.insert(0, str(parsed["msg"])) - if parsed.get("error"): - msgs.append(str(parsed["error"])) - ok = parsed.get("ok") is not False - if parsed.get("ok") is True: - ok = True - elif parsed.get("ok") is False: - ok = False - else: - for m in msgs: - if any( - k in str(m) - for k in ("失败", "错误", "无法", "缺少", "过期", "未找到", "不允许", "异常") - ): - ok = False - break - text = ";".join(str(x) for x in msgs if x) or ("成功" if ok else "操作失败") - return ok, text - - -@app.post("/api/trend/{exchange_id}/stop") -async def api_trend_plan_stop(exchange_id: str, body: TrendPlanActionBody): - ex = _find_exchange(exchange_id) - if not ex or not ex.get("enabled"): - raise HTTPException(status_code=404, detail="账户未启用") - if "trend" not in (ex.get("capabilities") or []): - raise HTTPException(status_code=400, detail="该账户未启用趋势计划监控") - pid = int(body.plan_id) - async with httpx.AsyncClient() as client: - parsed = await _fetch_flask_json( - client, ex, f"/api/hub/trend/stop/{pid}", method="POST" - ) - ok, text = _flask_hub_messages(parsed) - _schedule_board_refresh() - return {"ok": ok, "message": text, "payload": parsed} - - -@app.post("/api/trend/{exchange_id}/breakeven") -async def api_trend_plan_breakeven(exchange_id: str, body: TrendPlanActionBody): - ex = _find_exchange(exchange_id) - if not ex or not ex.get("enabled"): - raise HTTPException(status_code=404, detail="账户未启用") - if "trend" not in (ex.get("capabilities") or []): - raise HTTPException(status_code=400, detail="该账户未启用趋势计划监控") - pid = int(body.plan_id) - data = {} - if body.breakeven_offset_pct is not None: - data["breakeven_offset_pct"] = str(body.breakeven_offset_pct) - async with httpx.AsyncClient() as client: - parsed = await _fetch_flask_json( - client, - ex, - f"/api/hub/trend/breakeven/{pid}", - method="POST", - data=data, - ) - ok, text = _flask_hub_messages(parsed) - _schedule_board_refresh() - return {"ok": ok, "message": text, "payload": parsed} - - -@app.post("/api/orders/{exchange_id}/cancel") -async def api_cancel_order(exchange_id: str, body: CancelOrderBody): - ex = _find_exchange(exchange_id) - if not ex or not ex.get("enabled"): - raise HTTPException(status_code=404, detail="账户未启用") - url = f"{ex['agent_url'].rstrip('/')}/orders/cancel" - async with httpx.AsyncClient() as client: - r = await client.post( - url, - headers=_agent_headers(), - json={ - "symbol": body.symbol, - "order_id": body.order_id, - "channel": body.channel or "regular", - }, - timeout=60.0, - ) - try: - payload = r.json() - except Exception: - payload = {"raw": (r.text or "")[:2000]} - out = { - "exchange": ex, - "status_code": r.status_code, - "payload": payload, - "ok": bool(isinstance(payload, dict) and payload.get("ok")), - } - _schedule_board_refresh() - return out - - -@app.post("/api/orders/{exchange_id}/cancel-symbol") -async def api_cancel_symbol_orders(exchange_id: str, body: CancelSymbolOrdersBody): - ex = _find_exchange(exchange_id) - if not ex or not ex.get("enabled"): - raise HTTPException(status_code=404, detail="账户未启用") - url = f"{ex['agent_url'].rstrip('/')}/orders/cancel-symbol" - async with httpx.AsyncClient() as client: - r = await client.post( - url, - headers=_agent_headers(), - json={"symbol": body.symbol, "scope": body.scope or "all"}, - timeout=120.0, - ) - try: - payload = r.json() - except Exception: - payload = {"raw": (r.text or "")[:2000]} - out = { - "exchange": ex, - "status_code": r.status_code, - "payload": payload, - "ok": bool(isinstance(payload, dict) and payload.get("ok")), - } - _schedule_board_refresh() - return out - - -@app.post("/api/close/{exchange_id}/position") -async def api_close_position(exchange_id: str, body: ClosePositionBody): - ex = _find_exchange(exchange_id) - if not ex or not ex.get("enabled"): - raise HTTPException(status_code=404, detail="账户未启用") - sym = (body.symbol or "").strip() - side = (body.side or "").strip().lower() - if not sym: - raise HTTPException(status_code=400, detail="symbol 不能为空") - if side not in ("long", "short"): - raise HTTPException(status_code=400, detail="side 须为 long 或 short") - url = f"{ex['agent_url'].rstrip('/')}/emergency/close-position" - async with httpx.AsyncClient() as client: - r = await client.post( - url, - headers=_agent_headers(), - json={"symbol": sym, "side": side}, - timeout=120.0, - ) - try: - payload = r.json() - except Exception: - payload = {"raw": (r.text or "")[:2000]} - out = { - "exchange": ex, - "status_code": r.status_code, - "payload": payload, - "ok": bool(isinstance(payload, dict) and payload.get("ok")), - } - if out.get("ok"): - ex_key = (ex.get("key") or "").strip().lower() - async with httpx.AsyncClient() as flask_client: - if ex_key in ("gate", "gate_bot"): - order_sync = await _fetch_flask_json( - flask_client, - ex, - "/api/hub/order/sync-flat", - method="POST", - json_body={"symbol": sym, "side": side}, - ) - if isinstance(order_sync, dict): - out["order_sync"] = order_sync - if "trend" in (ex.get("capabilities") or []): - sync_parsed = await _fetch_flask_json( - flask_client, - ex, - "/api/hub/trend/sync-flat", - method="POST", - json_body={"symbol": sym, "side": side}, - ) - if isinstance(sync_parsed, dict): - out["trend_sync"] = sync_parsed - roll_sync = await _fetch_flask_json( - flask_client, - ex, - "/api/hub/roll/sync-flat", - method="POST", - json_body={"symbol": sym, "side": side}, - ) - if isinstance(roll_sync, dict): - out["roll_sync"] = roll_sync - risk_sync = await _notify_instance_user_close(flask_client, ex, count=1) - if isinstance(risk_sync, dict): - out["risk_sync"] = risk_sync - _schedule_board_refresh() - return out - - -@app.post("/api/orders/{exchange_id}/place-tpsl") -async def api_place_tpsl(exchange_id: str, body: PlaceTpslBody): - ex = _find_exchange(exchange_id) - if not ex or not ex.get("enabled"): - raise HTTPException(status_code=404, detail="账户未启用") - url = f"{ex['agent_url'].rstrip('/')}/orders/place-tpsl" - async with httpx.AsyncClient() as client: - r = await client.post( - url, - headers=_agent_headers(), - json={ - "symbol": body.symbol, - "side": body.side, - "stop_loss": body.stop_loss, - "take_profit": body.take_profit, - "contracts": body.contracts, - }, - timeout=120.0, - ) - try: - payload = r.json() - except Exception: - payload = {"raw": (r.text or "")[:2000]} - out = { - "exchange": ex, - "status_code": r.status_code, - "payload": payload, - "ok": bool(isinstance(payload, dict) and payload.get("ok")), - } - _schedule_board_refresh() - return out - - -@app.post("/api/close/{exchange_id}") -async def api_close_exchange(exchange_id: str): - ex = _find_exchange(exchange_id) - if not ex or not ex.get("enabled"): - raise HTTPException(status_code=404, detail="账户未启用") - url = f"{ex['agent_url'].rstrip('/')}/emergency/close-all" - async with httpx.AsyncClient() as client: - r = await client.post(url, headers=_agent_headers(), timeout=120.0) - try: - body = r.json() - except Exception: - body = {"raw": (r.text or "")[:2000]} - ok = bool(isinstance(body, dict) and body.get("ok")) - out = {"exchange": ex, "status_code": r.status_code, "payload": body, "ok": ok} - if ok and isinstance(body, dict): - closed = body.get("closed") or [] - n = len(closed) if isinstance(closed, list) else 0 - if n > 0: - risk_sync = await _notify_instance_user_close(client, ex, count=n) - if isinstance(risk_sync, dict): - out["risk_sync"] = risk_sync - _schedule_board_refresh() - return out - - -@app.post("/api/close-all") -async def api_close_all(body: CloseAllBody | None = Body(default=None)): - excl = set(body.exclude_ids if body else []) - excl |= env_force_disabled_ids() - targets = [x for x in enabled_exchanges() if str(x["id"]) not in excl] - async with httpx.AsyncClient() as client: - - async def one(ex: dict): - url = f"{ex['agent_url'].rstrip('/')}/emergency/close-all" - try: - r = await client.post(url, headers=_agent_headers(), timeout=120.0) - try: - payload = r.json() - except Exception: - payload = {"raw": (r.text or "")[:2000]} - row = {"id": ex["id"], "name": ex["name"], "status_code": r.status_code, "payload": payload} - if isinstance(payload, dict) and payload.get("ok"): - closed = payload.get("closed") or [] - n = len(closed) if isinstance(closed, list) else 0 - if n > 0: - risk_sync = await _notify_instance_user_close(client, ex, count=n) - if isinstance(risk_sync, dict): - row["risk_sync"] = risk_sync - return row - except Exception as e: - return {"id": ex["id"], "name": ex["name"], "status_code": None, "error": str(e)} - - results = await asyncio.gather(*[one(ex) for ex in targets]) - _schedule_board_refresh() - return {"results": list(results)} - - -def _trade_removed_response(): - """旧版前端或缓存页面仍会请求 /api/trade/*,勿解析表单,直接返回说明。""" - return JSONResponse( - { - "ok": False, - "result": { - "ok": False, - "messages": [ - "中控已移除下单区。请在监控卡片点击「实例」," - "进入对应 crypto_monitor_* 网页添加关键位或下单。" - ], - }, - "deprecated": True, - }, - status_code=410, - ) - - -def _parse_anchor_ms(at: str = "", anchor_ms: str = "") -> int | None: - raw = (anchor_ms or at or "").strip() - if not raw: - return None - return parse_wall_clock_ms(raw) - - -@app.get("/api/archive/meta") -def api_archive_meta(): - init_archive_db() - exchanges = [] - for ex in enabled_exchanges(load_settings()): - exchanges.append( - { - "id": ex.get("id"), - "key": ex.get("key"), - "name": ex.get("name"), - } - ) - return { - "ok": True, - "timeframes": sorted(ARCHIVE_TIMEFRAMES), - "default_timeframe": ARCHIVE_DEFAULT_TIMEFRAME, - "seed_lookback_days": ARCHIVE_SEED_LOOKBACK_DAYS, - "sync_interval_sec": ARCHIVE_SYNC_INTERVAL_SEC, - "visible_bars_default": ARCHIVE_VISIBLE_BARS_DEFAULT, - "exchanges": exchanges, - "last_sync": _last_archive_sync, - } - - -@app.get("/api/archive/list") -def api_archive_list( - exchange_key: str = "", - filter_profit: str = "", - filter_loss: str = "", - filter_sick: str = "", - filter_emotion: str = "", -): - init_archive_db() - rows = list_symbol_rows( - exchange_key=exchange_key, - filter_profit=(filter_profit or "").lower() in ("1", "true", "yes", "on"), - filter_loss=(filter_loss or "").lower() in ("1", "true", "yes", "on"), - filter_sick=(filter_sick or "").lower() in ("1", "true", "yes", "on"), - filter_emotion=(filter_emotion or "").lower() in ("1", "true", "yes", "on"), - ) - return {"ok": True, "rows": rows, "count": len(rows)} - - -@app.get("/api/archive/daily-trades") -def api_archive_daily_trades( - period: str = "", - trading_day: str = "", - date_from: str = "", - date_to: str = "", - exchange_key: str = "", - filter_profit: str = "", - filter_loss: str = "", - filter_sick: str = "", - search: str = "", -): - init_archive_db() - payload = list_daily_trades( - trading_day=trading_day, - period=period or "today", - date_from=date_from, - date_to=date_to, - exchange_key=exchange_key, - filter_profit=(filter_profit or "").lower() in ("1", "true", "yes", "on"), - filter_loss=(filter_loss or "").lower() in ("1", "true", "yes", "on"), - filter_sick=(filter_sick or "").lower() in ("1", "true", "yes", "on"), - search=search, - ) - return {"ok": True, **payload} - - -@app.get("/api/archive/calendar") -def api_archive_calendar( - year: int = 0, - month: int = 0, - exchange_key: str = "", -): - init_archive_db() - if year <= 0 or month <= 0: - td = today_trading_day() - parts = td.split("-") - year = int(parts[0]) - month = int(parts[1]) - try: - payload = list_archive_calendar(year, month, exchange_key=exchange_key) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - return {"ok": True, **payload} - - -@app.get("/api/archive/quotes") -def api_archive_quotes(): - init_archive_db() - rows = list_review_quotes() - return {"ok": True, "quotes": rows, "count": len(rows), "max": ARCHIVE_QUOTES_MAX} - - -class ArchiveQuoteBody(BaseModel): - quote_date: str = "" - content: str = "" - - -@app.post("/api/archive/quotes") -def api_archive_quote_create(body: ArchiveQuoteBody = Body(...)): - init_archive_db() - try: - row = create_review_quote(body.quote_date, body.content) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - return {"ok": True, "quote": row} - - -@app.patch("/api/archive/quotes/{quote_id}") -def api_archive_quote_update(quote_id: int, body: ArchiveQuoteBody = Body(...)): - init_archive_db() - try: - row = update_review_quote( - int(quote_id), - quote_date=body.quote_date or None, - content=body.content if body.content is not None else None, - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - if not row: - raise HTTPException(status_code=404, detail="语录不存在") - return {"ok": True, "quote": row} - - -@app.delete("/api/archive/quotes/{quote_id}") -def api_archive_quote_delete(quote_id: int): - init_archive_db() - if not delete_review_quote(int(quote_id)): - raise HTTPException(status_code=404, detail="语录不存在") - return {"ok": True, "id": int(quote_id)} - - -class MacroEventBody(BaseModel): - event_type: str = "" - event_at: str = "" - note: str = "" - - -@app.get("/api/macro-calendar/meta") -def api_macro_calendar_meta(): - init_macro_calendar_db() - return { - "ok": True, - "event_types": [ - {"id": k, "label": MACRO_EVENT_LABELS[k]} for k in MACRO_EVENT_TYPES - ], - "window_before_minutes": 60, - "window_after_minutes": 60, - "timezone": "Asia/Shanghai", - } - - -@app.get("/api/macro-calendar/events") -def api_macro_calendar_events(): - init_macro_calendar_db() - rows = list_macro_events() - return {"ok": True, "events": rows, "count": len(rows)} - - -@app.get("/api/macro-calendar/active") -def api_macro_calendar_active(): - init_macro_calendar_db() - alerts = list_active_alerts() - return {"ok": True, "alerts": alerts, "count": len(alerts)} - - -@app.post("/api/macro-calendar/events") -def api_macro_calendar_create(body: MacroEventBody = Body(...)): - init_macro_calendar_db() - try: - row = create_macro_event(body.event_type, body.event_at, note=body.note) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - return {"ok": True, "event": row} - - -@app.patch("/api/macro-calendar/events/{event_id}") -def api_macro_calendar_update(event_id: int, body: MacroEventBody = Body(...)): - init_macro_calendar_db() - try: - row = update_macro_event( - int(event_id), - event_type=body.event_type or None, - event_at=body.event_at or None, - note=body.note, - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - if not row: - raise HTTPException(status_code=404, detail="记录不存在") - return {"ok": True, "event": row} - - -@app.delete("/api/macro-calendar/events/{event_id}") -def api_macro_calendar_delete(event_id: int): - init_macro_calendar_db() - if not delete_macro_event(int(event_id)): - raise HTTPException(status_code=404, detail="记录不存在") - return {"ok": True, "id": int(event_id)} - - -@app.get("/api/archive/detail") -def api_archive_detail(exchange_key: str = "", symbol: str = ""): - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - if not ex_k or not sym: - raise HTTPException(status_code=400, detail="缺少 exchange_key 或 symbol") - init_archive_db() - trades = load_symbol_trades(ex_k, sym) - return {"ok": True, "exchange_key": ex_k, "symbol": sym, "trades": trades} - - -@app.get("/api/archive/ohlcv") -def api_archive_ohlcv( - exchange_key: str = "", - symbol: str = "", - timeframe: str = ARCHIVE_DEFAULT_TIMEFRAME, - mode: str = "hold", - anchor_ms: str = "", - opened_ms: str = "", - closed_ms: str = "", - range: str = "", - at: str = "", - bars: str = "", -): - ex_k = (exchange_key or "").strip().lower() - sym = (symbol or "").strip().upper() - if not ex_k or not sym: - raise HTTPException(status_code=400, detail="缺少 exchange_key 或 symbol") - init_archive_db() - anchor = _parse_anchor_ms(at, anchor_ms) - open_ms = _parse_anchor_ms("", opened_ms) - close_ms = _parse_anchor_ms("", closed_ms) - try: - bar_n = int(bars) if (bars or "").strip().isdigit() else ARCHIVE_VISIBLE_BARS_DEFAULT - except ValueError: - bar_n = ARCHIVE_VISIBLE_BARS_DEFAULT - result = resolve_archive_chart( - ex_k, - sym, - timeframe, - anchor_ms=anchor, - opened_ms=open_ms, - closed_ms=close_ms, - mode=mode, - bars=bar_n, - range_mode=(range or "").strip().lower() or "window", - ) - if not result.get("ok"): - raise HTTPException(status_code=404, detail=result.get("msg") or "无 K 线") - return result - - -class ArchiveOverlayBody(BaseModel): - behavior_tag: str = "" - note: str = "" - - -@app.patch("/api/archive/trade/{exchange_key}/{trade_id}") -def api_archive_trade_overlay( - exchange_key: str, - trade_id: int, - body: ArchiveOverlayBody = Body(...), -): - ex_k = (exchange_key or "").strip().lower() - if not ex_k: - raise HTTPException(status_code=400, detail="缺少 exchange_key") - init_archive_db() - out = upsert_trade_overlay( - ex_k, - int(trade_id), - behavior_tag=body.behavior_tag, - note=body.note, - ) - return {"ok": True, "overlay": out} - - -@app.delete("/api/archive/trade/{exchange_key}/{trade_id}") -def api_archive_trade_delete(exchange_key: str, trade_id: int): - from hub_symbol_archive_lib import delete_trade_from_archive - - ex_k = (exchange_key or "").strip().lower() - if not ex_k: - raise HTTPException(status_code=400, detail="缺少 exchange_key") - init_archive_db() - removed = delete_trade_from_archive(ex_k, int(trade_id)) - if not removed: - raise HTTPException(status_code=404, detail="档案中无该笔交易") - return {"ok": True, "exchange_key": ex_k, "trade_id": int(trade_id)} - - -@app.post("/api/archive/sync") -async def api_archive_sync(): - body = await _run_archive_sync_once() - return body - - -@app.get("/api/entry-plans/meta") -def api_entry_plans_meta(): - init_entry_plan_db() - exchanges = [] - for ex in enabled_exchanges(load_settings()): - exchanges.append( - { - "id": ex.get("id"), - "key": ex.get("key"), - "name": ex.get("name"), - } - ) - return {"ok": True, **entry_plan_meta_payload(exchanges)} - - -@app.get("/api/entry-plans") -def api_entry_plans_list(status: str = "active"): - init_entry_plan_db() - try: - rows = list_entry_plans(status=status) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - return {"ok": True, "plans": rows, "count": len(rows), "status": status.strip().lower()} - - -@app.get("/api/entry-plans/stats") -def api_entry_plan_stats( - dimension: str = "symbol", - period: str = "all", - date_from: str = "", - date_to: str = "", -): - init_entry_plan_db() - try: - stats = compute_entry_plan_stats( - dimension=dimension, - period=period, - date_from=date_from, - date_to=date_to, - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - return {"ok": True, "stats": stats} - - -@app.get("/api/entry-plans/{plan_id}") -def api_entry_plan_detail(plan_id: int): - init_entry_plan_db() - row = get_entry_plan(int(plan_id)) - if not row: - raise HTTPException(status_code=404, detail="计划不存在") - return {"ok": True, "plan": row} - - -class EntryPlanBody(BaseModel): - plan_date: str = "" - exchange_key: str = "" - symbol: str = "" - plan_type: str = "" - trend_timeframe: str = "" - entry_timeframe: str = "" - direction: str = "" - target_level: str = "" - current_range: str = "" - entry_scheme: str = "" - result: str | None = None - pnl_amount: float | None = None - note: str = "" - - -@app.post("/api/entry-plans") -def api_entry_plan_create(body: EntryPlanBody = Body(...)): - init_entry_plan_db() - try: - row = create_entry_plan(body.model_dump(exclude_unset=True)) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - return {"ok": True, "plan": row} - - -@app.patch("/api/entry-plans/{plan_id}") -def api_entry_plan_update(plan_id: int, body: EntryPlanBody = Body(...)): - init_entry_plan_db() - payload = body.model_dump(exclude_unset=True) - if not payload: - raise HTTPException(status_code=400, detail="无更新字段") - try: - row = update_entry_plan(int(plan_id), payload) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - if not row: - raise HTTPException(status_code=404, detail="计划不存在") - return {"ok": True, "plan": row} - - -@app.delete("/api/entry-plans/{plan_id}") -def api_entry_plan_delete(plan_id: int): - init_entry_plan_db() - try: - ok = delete_entry_plan(int(plan_id)) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - if not ok: - raise HTTPException(status_code=404, detail="计划不存在或已归档") - return {"ok": True, "id": int(plan_id)} - - -@app.get("/api/hub/fund-overview") -def api_hub_fund_overview(): - from hub_fund_history_lib import build_fund_overview - from hub_ai.config import trading_day_reset_hour - - settings = load_settings() - snap = board_store.snapshot_dict() - payload = build_fund_overview( - enabled_exchanges(settings), - board_rows=snap.get("rows") or [], - reset_hour=trading_day_reset_hour(), - updated_at=snap.get("updated_at"), - ) - return payload - - -@app.get("/api/ping") -def api_ping(): - return { - "ok": True, - "service": "manual-trading-hub", - "build": HUB_BUILD, - "trade_ui": False, - "features": ["monitor", "settings", "auth", "board_sse", "dashboard_sse", "archive", "dashboard", "funds", "macro_calendar"], - "board_poll_interval_sec": HUB_BOARD_POLL_INTERVAL, - "board_version": board_store.version, - "board_aggregating": board_store.aggregating, - "board_updated_at": (board_store.payload or {}).get("updated_at") - if isinstance(board_store.payload, dict) - else None, - "board_error": board_store.last_error, - "dashboard_poll_interval_sec": DASHBOARD_POLL_INTERVAL_SEC, - "dashboard_version": dashboard_store.version, - "dashboard_aggregating": dashboard_store.aggregating, - "dashboard_updated_at": (dashboard_store.payload or {}).get("updated_at") - if isinstance(dashboard_store.payload, dict) - else None, - "dashboard_error": dashboard_store.last_error, - "password_required": password_required(), - "env_disabled_ids": sorted(env_force_disabled_ids()), - "hub_disabled_ids_raw": (os.getenv("HUB_DISABLED_IDS") or ""), - } - - -@app.post("/api/trade/order/{exchange_id}") -@app.post("/api/trade/key/{exchange_id}") -@app.post("/api/trade/trend/preview/{exchange_id}") -@app.post("/api/trade/trend/execute/{exchange_id}") -async def api_trade_removed(exchange_id: str): - return _trade_removed_response() - - -@app.get("/api/trade/meta/{exchange_id}") -@app.get("/api/trade/trend/preview/{exchange_id}/{preview_id}") -async def api_trade_removed_get(exchange_id: str, preview_id: str = ""): - return _trade_removed_response() - - -def main(): - import uvicorn - - print( - f"manual-trading-hub start build={HUB_BUILD} listen={HUB_HOST}:{HUB_PORT}", - flush=True, - ) - uvicorn.run(app, host=HUB_HOST, port=HUB_PORT, log_level="info", access_log=False) - - -if __name__ == "__main__": - main() +""" +多账户交易中控:监控区 / 系统设置。 +聚合各实例监控数据与子代理 /status;下单请在各实例网页操作。 +""" +from __future__ import annotations + +import asyncio +import os +import sys +from contextlib import asynccontextmanager +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parent.parent +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from lib.hub.hub_kline_store import format_ohlcv_detail, resolve_chart_bars, retention_days +from lib.hub.hub_ohlcv_lib import ( + CHART_TIMEFRAME_ORDER, + CHART_TIMEFRAMES, + bar_limit_for_timeframe, + chart_chunk_limit, + chart_initial_limit, + chart_memory_cap, + retention_policy_meta, +) +from lib.hub.hub_volume_rank_lib import ( + TOP_N_DEFAULT, + _exchange_rank_row_stale, + cache_needs_refresh, + format_volume_quote, + get_cached_rank, + load_volume_rank_cache, + merge_exchange_rank, + rank_date_label, + save_volume_rank_cache, + seconds_until_next_reset, + volume_rank_reset_hour, +) +from lib.hub.hub_symbol_archive_lib import ( + ARCHIVE_DEFAULT_TIMEFRAME, + ARCHIVE_QUOTES_MAX, + ARCHIVE_SEED_LOOKBACK_DAYS, + ARCHIVE_SYNC_INTERVAL_SEC, + ARCHIVE_TIMEFRAMES, + ARCHIVE_TRADE_DAYS, + ARCHIVE_TRADE_LIMIT, + ARCHIVE_VISIBLE_BARS_DEFAULT, + create_review_quote, + delete_review_quote, + init_db as init_archive_db, + list_daily_trades, + list_archive_calendar, + list_review_quotes, + list_symbol_rows, + load_symbol_trades, + parse_wall_clock_ms, + resolve_archive_chart, + sync_exchange_symbol_archives, + today_trading_day, + update_review_quote, + upsert_trade_overlay, +) +from lib.hub.hub_entry_plan_lib import ( + compute_entry_plan_stats, + create_entry_plan, + delete_entry_plan, + get_entry_plan, + init_db as init_entry_plan_db, + list_entry_plans, + meta_payload as entry_plan_meta_payload, + update_entry_plan, +) +from lib.hub.hub_macro_calendar_lib import ( + MACRO_EVENT_LABELS, + MACRO_EVENT_TYPES, + create_event as create_macro_event, + delete_event as delete_macro_event, + init_db as init_macro_calendar_db, + list_active_alerts, + list_events as list_macro_events, + update_event as update_macro_event, +) +from env_load import load_hub_dotenv + +load_hub_dotenv() + +import httpx +from fastapi import Body, FastAPI, HTTPException, Request +from fastapi.responses import FileResponse, JSONResponse +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel, Field + +from settings_store import ( + enabled_exchanges, + env_force_disabled_ids, + load_settings, + normalize_display_prefs, + normalize_supervisor_settings, + save_settings, +) +from hub_web_auth import ( + SESSION_COOKIE, + SESSION_MAX_AGE_SEC, + clear_session_cookie, + cookie_secure_for_request, + create_session_token, + embed_allowed, + embed_frame_ancestors, + is_public_path, + password_required, + set_session_cookie, + validate_session_token, + expected_username, + verify_credentials, +) +from lib.hub.hub_sso import HUB_SSO_TTL_SEC, mint_hub_sso_token, safe_next_path +from url_public import browser_url, default_review_url, public_origin +from urllib.parse import urlencode + +from hub_board_cache import HUB_BOARD_POLL_INTERVAL, board_store +from hub_dashboard_cache import dashboard_store +from hub_dashboard import DASHBOARD_POLL_INTERVAL_SEC +from hub_supervisor_cache import supervisor_store +from hub_supervisor_lib import process_supervisor_tick, set_supervisor_notify_hook +from hub_ai.supervisor import make_supervisor_ai_reply_fn +from hub_ai.config import trading_day_reset_hour +from hub_chart_cache import ( + HUB_CHART_POLL_INTERVAL, + HUB_CHART_WATCH_TTL_SEC, + chart_poll_store, + parse_series_key, +) + +try: + from exchange_orders import symbols_match as _symbols_match +except ImportError: + + def _symbols_match(position_symbol: str, order_symbol: str) -> bool: + a = (position_symbol or "").strip().upper() + b = (order_symbol or "").strip().upper() + return bool(a and b and a == b) + +HUB_HOST = os.getenv("HUB_HOST", "0.0.0.0") +HUB_PORT = int(os.getenv("HUB_PORT", "5100")) +HUB_BRIDGE_TOKEN = (os.getenv("HUB_BRIDGE_TOKEN") or os.getenv("CONTROL_TOKEN") or "").strip() +_trust_raw = (os.getenv("HUB_TRUST_LAN", "true") or "").strip().lower() +HUB_TRUST_LAN = _trust_raw not in ("0", "false", "no", "off") +_allow_pub_raw = (os.getenv("HUB_ALLOW_PUBLIC") or "").strip().lower() +# 云服务器 + 域名反代时设为 true:不做 IP 限制,仅靠 HUB_PASSWORD / 登录页保护 +HUB_ALLOW_PUBLIC = _allow_pub_raw in ("1", "true", "yes", "on") +DIR = Path(__file__).resolve().parent +HUB_BUILD = "20260607-hub-archive" +_archive_sync_stop: asyncio.Event | None = None +_archive_sync_task: asyncio.Task | None = None +_last_archive_sync: dict | None = None +_volume_rank_stop: asyncio.Event | None = None +_volume_rank_task: asyncio.Task | None = None +_volume_rank_cache: dict | None = None +HUB_AGENT_TIMEOUT = float(os.getenv("HUB_AGENT_TIMEOUT", "8")) +HUB_FLASK_TIMEOUT = float(os.getenv("HUB_FLASK_TIMEOUT", "10")) +HUB_BOARD_TIMEOUT = float(os.getenv("HUB_BOARD_TIMEOUT", "45")) +_board_key_prices_raw = (os.getenv("HUB_BOARD_KEY_PRICES", "true") or "").strip().lower() +HUB_BOARD_KEY_PRICES = _board_key_prices_raw in ("1", "true", "yes", "on") + + +def _is_local(host: str | None) -> bool: + if not host: + return False + h = host.lower() + return h in ("127.0.0.1", "::1", "localhost") or h.startswith("::ffff:127.0.0.1") + + +def _ipv4_rfc1918_private(host: str) -> bool: + h = host.lower() + if h.startswith("::ffff:"): + h = h[7:] + parts = h.split(".") + if len(parts) != 4: + return False + try: + a, b, c, d = (int(x) for x in parts) + except ValueError: + return False + if any(x < 0 or x > 255 for x in (a, b, c, d)): + return False + if a == 10: + return True + if a == 172 and 16 <= b <= 31: + return True + if a == 192 and b == 168: + return True + return False + + +def _client_allowed(host: str | None) -> bool: + if _is_local(host): + return True + if HUB_TRUST_LAN and host and _ipv4_rfc1918_private(host): + return True + return False + + +def _hub_headers() -> dict[str, str]: + if not HUB_BRIDGE_TOKEN: + return {} + return {"X-Hub-Token": HUB_BRIDGE_TOKEN} + + +def _agent_headers() -> dict[str, str]: + if not HUB_BRIDGE_TOKEN: + return {} + return {"X-Control-Token": HUB_BRIDGE_TOKEN} + + +def _find_exchange(ex_id: str) -> dict | None: + for ex in load_settings().get("exchanges") or []: + if str(ex.get("id")) == str(ex_id): + return ex + return None + + +async def _run_chart_poll() -> dict: + keys = chart_poll_store.active_series_keys() + if not keys: + return {"ok": True, "series_count": 0, "polled": 0} + polled = 0 + errors: list[str] = [] + for key in keys: + parsed = parse_series_key(key) + if not parsed: + continue + ex_k, sym, tf = parsed + ex = _find_exchange_by_key(ex_k) + if not ex or not ex.get("enabled"): + continue + + ex_ref = ex + sym_ref = sym + tf_ref = tf + + def remote_fetch(**kwargs) -> dict: + tf_use = kwargs.get("timeframe") or tf_ref + return _fetch_instance_ohlcv_sync( + ex_ref, + symbol=kwargs.get("symbol") or sym_ref, + timeframe=tf_use, + since_ms=kwargs.get("since_ms"), + limit=int(kwargs.get("limit") or bar_limit_for_timeframe(tf_use)), + ) + + try: + result = await asyncio.to_thread( + resolve_chart_bars, + ex_k, + sym, + tf, + remote_fetch, + force_refresh=False, + tail_refresh=True, + ) + polled += 1 + chart_poll_store.note_series_result( + ex_k, + sym, + tf, + ok=bool(result.get("ok")), + fetched=int(result.get("fetched") or 0), + error=None if result.get("ok") else str(result.get("msg") or "poll_failed"), + candles=result.get("candles") if result.get("ok") else None, + price_tick=result.get("price_tick"), + ) + if not result.get("ok"): + errors.append(f"{key}:{result.get('msg')}") + except Exception as e: + chart_poll_store.note_series_result(ex_k, sym, tf, ok=False, error=str(e)) + errors.append(f"{key}:{e}") + out: dict = {"ok": True, "series_count": len(keys), "polled": polled} + if errors: + out["errors"] = errors[:8] + return out + + +async def _run_board_aggregate() -> dict: + try: + body = await asyncio.wait_for(_build_monitor_board_payload(), timeout=HUB_BOARD_TIMEOUT) + try: + from lib.hub.hub_fund_history_lib import record_fund_snapshot_from_board + + await asyncio.to_thread(record_fund_snapshot_from_board, body.get("rows") or []) + except Exception: + pass + return {"ok": True, **body} + except asyncio.TimeoutError: + return { + "ok": False, + "rows": [], + "error": "board_timeout", + "msg": ( + f"监控聚合超过 {int(HUB_BOARD_TIMEOUT)} 秒。" + "请检查子代理/Flask,或设 HUB_BOARD_KEY_PRICES=false、缩短 HUB_FLASK_TIMEOUT" + ), + "updated_at": __import__("datetime").datetime.now().isoformat(timespec="seconds"), + } + + +def _schedule_board_refresh() -> None: + board_store.request_refresh() + dashboard_store.request_refresh() + supervisor_store.request_refresh() + + +async def _run_archive_sync_once() -> dict: + global _last_archive_sync + init_archive_db() + settings = load_settings() + targets = enabled_exchanges(settings) + results: list[dict] = [] + for ex in targets: + ex_key = str(ex.get("key") or "").strip().lower() + if not ex_key: + continue + trades_resp = await asyncio.to_thread( + _fetch_instance_trades_archive_sync, + ex, + days=ARCHIVE_TRADE_DAYS, + limit=ARCHIVE_TRADE_LIMIT, + ) + if not trades_resp.get("ok"): + st = trades_resp.get("status") + msg = ( + trades_resp.get("msg") + or trades_resp.get("error") + or trades_resp.get("detail") + or "拉取交易失败" + ) + if st == 404: + msg = ( + "HTTP 404:该 Flask 未注册 /api/hub/trades/archive。" + "请在仓库根目录 git pull 后 pm2 restart crypto_gate crypto_gate_bot" + ) + results.append( + { + "exchange_key": ex_key, + "name": ex.get("name"), + "ok": False, + "status": st, + "msg": msg, + } + ) + continue + trades = trades_resp.get("trades") or [] + for t in trades: + if isinstance(t, dict): + t["exchange_key"] = ex_key + + def remote_fetch(**kwargs): + return _fetch_instance_ohlcv_sync( + ex, + symbol=kwargs.get("symbol") or "", + timeframe=kwargs.get("timeframe") or "5m", + since_ms=kwargs.get("since_ms"), + limit=int(kwargs.get("limit") or 500), + ) + + r = await asyncio.to_thread( + sync_exchange_symbol_archives, + ex_key, + trades, + remote_fetch, + ) + r["name"] = ex.get("name") + r["trade_count"] = len(trades) + results.append(r) + out = { + "ok": True, + "exchanges": len(targets), + "results": results, + "updated_at": __import__("datetime").datetime.now().isoformat(timespec="seconds"), + } + _last_archive_sync = out + return out + + +def _fetch_instance_volume_rank_sync(ex: dict, *, top_n: int = TOP_N_DEFAULT) -> dict: + base = (ex.get("flask_url") or "").rstrip("/") + if not base: + return {"ok": False, "msg": "未配置 flask_url"} + params = {"top": str(int(top_n))} + url = f"{base}/api/hub/volume-rank?{urlencode(params)}" + try: + with httpx.Client(timeout=max(HUB_FLASK_TIMEOUT, 120.0)) as client: + r = client.get(url, headers=_hub_headers()) + if r.status_code >= 400: + parsed = _parse_http_json_body(r) + parsed.setdefault("ok", False) + parsed.setdefault("status", r.status_code) + return parsed + data = r.json() if r.content else {} + return data if isinstance(data, dict) else {"ok": False, "msg": "无效 JSON"} + except Exception as e: + return {"ok": False, "msg": str(e)} + + +def _get_volume_rank_cache() -> dict: + global _volume_rank_cache + if _volume_rank_cache is None: + _volume_rank_cache = load_volume_rank_cache() + return _volume_rank_cache + + +def _refresh_volume_ranks(*, force: bool = False) -> dict: + global _volume_rank_cache + expected = rank_date_label() + cache = _get_volume_rank_cache() + targets = enabled_exchanges(load_settings()) + required_keys = [ + str(ex.get("key") or "").strip().lower() + for ex in targets + if ex.get("enabled") and str(ex.get("key") or "").strip() + ] + if not force and not cache_needs_refresh( + cache, expected_rank_date=expected, required_keys=required_keys + ): + return { + "ok": True, + "skipped": True, + "rank_date": cache.get("rank_date"), + "updated_at": cache.get("updated_at"), + } + errors: list[str] = [] + for ex in targets: + ex_key = str(ex.get("key") or "").strip().lower() + if not ex_key or not ex.get("enabled"): + continue + resp = _fetch_instance_volume_rank_sync(ex, top_n=TOP_N_DEFAULT) + if resp.get("ok") and resp.get("items"): + cache = merge_exchange_rank(cache, ex_key, resp) + else: + msg = str(resp.get("msg") or resp.get("error") or "拉取失败") + if resp.get("ok") and not resp.get("items"): + msg = msg if msg != "拉取失败" else "无有效成交额数据" + errors.append(f"{ex_key}:{msg}") + exchanges = dict(cache.get("exchanges") or {}) + prev = dict(exchanges.get(ex_key) or {}) + prev["error"] = msg + if not prev.get("items"): + prev["items"] = [] + exchanges[ex_key] = prev + cache["exchanges"] = exchanges + cache["rank_date"] = expected + save_volume_rank_cache(cache) + _volume_rank_cache = cache + out: dict = { + "ok": True, + "rank_date": expected, + "exchanges": len(targets), + "updated_at": cache.get("updated_at"), + } + if errors: + out["errors"] = errors[:8] + return out + + +async def _volume_rank_loop() -> None: + global _volume_rank_stop + stop = _volume_rank_stop + if stop is None: + return + try: + await asyncio.to_thread(_refresh_volume_ranks, force=False) + except Exception: + pass + while not stop.is_set(): + try: + wait_sec = seconds_until_next_reset() + await asyncio.wait_for(stop.wait(), timeout=wait_sec) + break + except asyncio.TimeoutError: + pass + if stop.is_set(): + break + try: + await asyncio.to_thread(_refresh_volume_ranks, force=True) + except Exception: + pass + + +async def _archive_sync_loop() -> None: + global _archive_sync_stop + stop = _archive_sync_stop + if stop is None: + return + init_archive_db() + while not stop.is_set(): + try: + await _run_archive_sync_once() + except Exception: + pass + try: + await asyncio.wait_for(stop.wait(), timeout=float(ARCHIVE_SYNC_INTERVAL_SEC)) + except asyncio.TimeoutError: + pass + + +async def _run_supervisor_tick() -> dict: + dash = dashboard_store.snapshot_dict() + board = board_store.snapshot_dict() + settings = load_settings() + ai_fn = make_supervisor_ai_reply_fn(_all_exchanges_for_ai()) + return await asyncio.to_thread( + process_supervisor_tick, + dash if dash.get("ok") is not False else None, + board if board.get("ok") is not False else None, + settings, + reset_hour=trading_day_reset_hour(), + ai_reply_fn=ai_fn, + ) + + +@asynccontextmanager +async def _hub_lifespan(_app: FastAPI): + global _archive_sync_stop, _archive_sync_task, _volume_rank_stop, _volume_rank_task + set_supervisor_notify_hook(supervisor_store.bump) + await board_store.start(_run_board_aggregate) + await dashboard_store.start(_run_dashboard_aggregate) + await supervisor_store.start(_run_supervisor_tick) + await chart_poll_store.start(_run_chart_poll) + _archive_sync_stop = asyncio.Event() + _archive_sync_task = asyncio.create_task(_archive_sync_loop(), name="hub-archive-sync") + _volume_rank_stop = asyncio.Event() + _volume_rank_task = asyncio.create_task(_volume_rank_loop(), name="hub-volume-rank") + try: + yield + finally: + if _archive_sync_stop: + _archive_sync_stop.set() + if _archive_sync_task: + _archive_sync_task.cancel() + try: + await _archive_sync_task + except asyncio.CancelledError: + pass + _archive_sync_task = None + _archive_sync_stop = None + if _volume_rank_stop: + _volume_rank_stop.set() + if _volume_rank_task: + _volume_rank_task.cancel() + try: + await _volume_rank_task + except asyncio.CancelledError: + pass + _volume_rank_task = None + _volume_rank_stop = None + await chart_poll_store.stop() + await supervisor_store.stop() + await dashboard_store.stop() + await board_store.stop() + set_supervisor_notify_hook(None) + + +app = FastAPI(title="复盘系统中控", docs_url=None, redoc_url=None, lifespan=_hub_lifespan) +STATIC_DIR = DIR / "static" +_REPO_STATIC = _REPO_ROOT / "lib" / "common" / "static" +_AI_REVIEW_RENDER_JS = _REPO_STATIC / "ai_review_render.js" +_TRADE_STATS_CALENDAR_CSS = _REPO_STATIC / "trade_stats_calendar.css" +_TRADE_STATS_CALENDAR_JS = _REPO_STATIC / "trade_stats_calendar.js" +_ACCOUNT_RISK_BADGE_CSS = _REPO_STATIC / "account_risk_badge.css" +_ACCOUNT_RISK_BADGE_JS = _REPO_STATIC / "account_risk_badge.js" + + +@app.get("/assets/account_risk_badge.css") +def hub_account_risk_badge_css(): + """与四所实例共用仓库根 static/account_risk_badge.css。""" + if not _ACCOUNT_RISK_BADGE_CSS.is_file(): + raise HTTPException(status_code=404, detail="account_risk_badge.css not found") + return FileResponse( + str(_ACCOUNT_RISK_BADGE_CSS), + media_type="text/css; charset=utf-8", + ) + + +@app.get("/assets/account_risk_badge.js") +def hub_account_risk_badge_js(): + """与四所实例共用仓库根 static/account_risk_badge.js。""" + if not _ACCOUNT_RISK_BADGE_JS.is_file(): + raise HTTPException(status_code=404, detail="account_risk_badge.js not found") + return FileResponse( + str(_ACCOUNT_RISK_BADGE_JS), + media_type="application/javascript; charset=utf-8", + ) + + +@app.get("/assets/ai_review_render.js") +def hub_ai_review_render_js(): + """与四所实例共用仓库根 static/ai_review_render.js(须在 /assets mount 之前注册)。""" + if not _AI_REVIEW_RENDER_JS.is_file(): + raise HTTPException(status_code=404, detail="ai_review_render.js not found") + return FileResponse( + str(_AI_REVIEW_RENDER_JS), + media_type="application/javascript; charset=utf-8", + ) + + +@app.get("/assets/trade_stats_calendar.css") +def hub_trade_stats_calendar_css(): + if not _TRADE_STATS_CALENDAR_CSS.is_file(): + raise HTTPException(status_code=404, detail="trade_stats_calendar.css not found") + return FileResponse( + str(_TRADE_STATS_CALENDAR_CSS), + media_type="text/css; charset=utf-8", + ) + + +@app.get("/assets/trade_stats_calendar.js") +def hub_trade_stats_calendar_js(): + if not _TRADE_STATS_CALENDAR_JS.is_file(): + raise HTTPException(status_code=404, detail="trade_stats_calendar.js not found") + return FileResponse( + str(_TRADE_STATS_CALENDAR_JS), + media_type="application/javascript; charset=utf-8", + ) + + +if STATIC_DIR.is_dir(): + app.mount("/assets", StaticFiles(directory=str(STATIC_DIR)), name="assets") + + +@app.middleware("http") +async def local_only(request: Request, call_next): + if HUB_ALLOW_PUBLIC: + return await call_next(request) + peer = request.client.host if request.client else None + if not _client_allowed(peer): + return JSONResponse({"detail": "forbidden"}, status_code=403) + return await call_next(request) + + +@app.middleware("http") +async def embed_frame_headers(request: Request, call_next): + response = await call_next(request) + if embed_allowed(): + ancestors = embed_frame_ancestors() + if ancestors == "*": + response.headers["Content-Security-Policy"] = "frame-ancestors *" + else: + response.headers["Content-Security-Policy"] = f"frame-ancestors 'self' {ancestors}" + return response + + +@app.middleware("http") +async def hub_password_gate(request: Request, call_next): + if not password_required(): + return await call_next(request) + path = request.url.path + if is_public_path(path, request.method): + return await call_next(request) + token = request.cookies.get(SESSION_COOKIE) + if validate_session_token(token): + return await call_next(request) + if path.startswith("/api/"): + return JSONResponse({"detail": "未登录", "login_required": True}, status_code=401) + from fastapi.responses import RedirectResponse + + nxt = path if path.startswith("/") else "/monitor" + return RedirectResponse(f"/login?next={nxt}", status_code=302) + + +def _shell_page(): + index = STATIC_DIR / "index.html" + if not index.is_file(): + return JSONResponse({"detail": "missing static/index.html"}, status_code=500) + return FileResponse(index) + + +def _login_page(): + login = STATIC_DIR / "login.html" + if not login.is_file(): + return JSONResponse({"detail": "missing static/login.html"}, status_code=500) + return FileResponse(login) + + +class LoginBody(BaseModel): + username: str = "" + password: str = "" + + +@app.get("/api/auth/status") +def api_auth_status(request: Request): + required = password_required() + logged_in = not required or validate_session_token(request.cookies.get(SESSION_COOKIE)) + return { + "required": required, + "logged_in": logged_in, + } + + +@app.post("/api/auth/login") +def api_auth_login(body: LoginBody, request: Request): + if not password_required(): + return {"ok": True, "auth_disabled": True} + if not verify_credentials(body.username, body.password): + raise HTTPException(status_code=401, detail="用户名或密码错误") + token = create_session_token(body.username) + embed = (request.headers.get("x-hub-embed") or "").strip() == "1" + resp = JSONResponse({"ok": True, "session_token": token, "embed": embed}) + set_session_cookie(resp, request, token, embed=embed) + return resp + + +@app.get("/embed-auth") +def embed_auth_login(request: Request, token: str = "", next: str = "/monitor"): + """ + 嵌入式打开:父页跨域 fetch 登录时 Cookie 可能写不进 iframe, + 用 session_token 在本页做一次导航,在 iframe 内写入 hub_sess。 + """ + from fastapi.responses import RedirectResponse + + dest = safe_next_path(next) + if not password_required(): + return RedirectResponse(dest, status_code=302) + if not validate_session_token(token): + q = urlencode({"next": dest, "embed": "1"}) + return RedirectResponse(f"/login?{q}", status_code=302) + resp = RedirectResponse(dest, status_code=302) + set_session_cookie(resp, request, token, embed=True) + return resp + + +@app.post("/api/auth/logout") +def api_auth_logout(request: Request): + embed = (request.headers.get("x-hub-embed") or "").strip() == "1" + resp = JSONResponse({"ok": True}) + clear_session_cookie(resp, request, embed=embed) + return resp + + +@app.get("/login") +def login_page(): + return _login_page() + + +@app.get("/") +def root_redirect(): + from fastapi.responses import RedirectResponse + + return RedirectResponse("/monitor") + + +@app.get("/monitor") +@app.get("/plan") +@app.get("/calculator") +@app.get("/market") +@app.get("/archive") +@app.get("/dashboard") +@app.get("/funds") +@app.get("/ai") +@app.get("/settings") +def shell_pages(): + return _shell_page() + + +def _all_exchanges_for_ai() -> list: + """AI 聚合用:含未启用账户(标记未监控)。""" + return list(load_settings().get("exchanges") or []) + + +from hub_ai.routes import create_hub_ai_router +from hub_dashboard import build_dashboard_payload, default_trading_day + +app.include_router(create_hub_ai_router(load_all_exchanges=_all_exchanges_for_ai)) + + +async def _run_dashboard_aggregate() -> dict: + try: + return await asyncio.to_thread( + build_dashboard_payload, + enabled_exchanges(), + trading_day=default_trading_day(), + ) + except Exception as exc: + return {"ok": False, "msg": str(exc), "error": "aggregate_failed"} + + +def _schedule_dashboard_refresh() -> None: + dashboard_store.request_refresh() + supervisor_store.request_refresh() + + +@app.get("/api/dashboard/daily") +def api_dashboard_daily(trading_day: str = ""): + day = (trading_day or "").strip()[:10] or default_trading_day() + if not (trading_day or "").strip(): + return dashboard_store.snapshot_dict() + try: + payload = build_dashboard_payload( + enabled_exchanges(), + trading_day=day, + ) + except Exception as exc: + raise HTTPException(status_code=502, detail=str(exc)) from exc + return {**payload, "dashboard_version": dashboard_store.version} + + +@app.get("/api/dashboard/stream") +async def api_dashboard_stream(): + from fastapi.responses import StreamingResponse + + return StreamingResponse( + dashboard_store.iter_sse(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + +@app.post("/api/dashboard/refresh") +async def api_dashboard_refresh(): + _schedule_dashboard_refresh() + return {"ok": True, "dashboard_version": dashboard_store.version} + + +@app.get("/api/ai/supervisor/stream") +async def api_supervisor_stream(): + from fastapi.responses import StreamingResponse + + return StreamingResponse( + supervisor_store.iter_sse(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + +@app.post("/api/ai/supervisor/refresh") +async def api_supervisor_refresh(): + supervisor_store.request_refresh() + return {"ok": True, "supervisor_version": supervisor_store.version} + + +@app.get("/trade") +def trade_removed_redirect(): + from fastapi.responses import RedirectResponse + + return RedirectResponse("/monitor", status_code=302) + + +@app.get("/api/settings") +def api_get_settings(): + return load_settings() + + +class SettingsDisplayBody(BaseModel): + show_account_pnl: bool = True + show_nav_funds: bool = True + show_nav_dashboard: bool = True + show_nav_plan: bool = True + show_nav_archive: bool = True + show_nav_ai: bool = True + show_nav_calculator: bool = True + + +class SupervisorSettingsBody(BaseModel): + enabled: bool = True + wechat_webhook: str = "" + wechat_link_base: str = "" + wechat_prefix: str = "【交易监管】" + wechat_on_program_tp_sl: bool = True + manual_close_daily_warn: int = 2 + interval_warn_minutes: int = 15 + freq_30m_count: int = 2 + reopen_after_close_minutes: int = 30 + + +class SettingsBody(BaseModel): + exchanges: list[dict] = Field(default_factory=list) + display: SettingsDisplayBody | None = None + supervisor: SupervisorSettingsBody | None = None + + +@app.post("/api/settings") +def api_save_settings(body: SettingsBody): + force_off = env_force_disabled_ids() + to_save = [] + for ex in body.exchanges: + row = dict(ex) + eid = str(row.get("id", "")).strip() + if eid in force_off: + row["enabled"] = False + row.pop("env_disabled", None) + to_save.append(row) + existing = load_settings() + display = normalize_display_prefs(existing.get("display")) + if body.display is not None: + display = normalize_display_prefs(body.display.model_dump()) + supervisor = normalize_supervisor_settings(existing.get("supervisor")) + if body.supervisor is not None: + supervisor = normalize_supervisor_settings(body.supervisor.model_dump()) + save_settings({"version": 1, "exchanges": to_save, "display": display, "supervisor": supervisor}) + return {"ok": True, "settings": load_settings()} + + +class TrendCalculatorBody(BaseModel): + direction: str = "long" + capital_usdt: float = Field(gt=0) + risk_percent: float = Field(gt=0, le=100) + leverage: int = Field(ge=1, le=125) + entry_price: float = Field(gt=0) + stop_loss: float = Field(gt=0) + add_upper: float = Field(gt=0) + take_profit: float = Field(gt=0) + dca_legs: int = Field(default=5, ge=1, le=20) + exchange_id: str = "0" + base: str = "ETH" + + +class RollAddLegBody(BaseModel): + add_price: float = Field(gt=0) + new_stop_loss: float = Field(gt=0) + + +class RollCalculatorBody(BaseModel): + direction: str = "long" + capital_usdt: float = Field(gt=0) + risk_percent: float = Field(gt=0, le=100) + entry_price: float = Field(gt=0) + stop_loss: float = Field(gt=0) + take_profit: float = Field(gt=0) + add_legs: list[RollAddLegBody] = Field(default_factory=list, max_length=3) + legs_done: int = Field(default=0, ge=0, le=3) + exchange_id: str = "0" + base: str = "ETH" + + +@app.get("/api/calculator/exchanges") +def api_calculator_exchanges(): + from lib.hub.hub_calculator_market_lib import list_calculator_exchanges + + return {"ok": True, "data": list_calculator_exchanges()} + + +@app.get("/api/calculator/market") +def api_calculator_market(exchange_id: str = "0", base: str = "ETH"): + from lib.hub.hub_calculator_market_lib import get_calculator_market + + data, err = get_calculator_market(exchange_id, base) + if err: + return JSONResponse({"ok": False, "msg": err}, status_code=400) + return {"ok": True, "data": data} + + +@app.post("/api/calculator/trend") +def api_calculator_trend(body: TrendCalculatorBody): + from lib.hub.hub_calculator_lib import calc_trend_calculator + + data, err = calc_trend_calculator( + direction=body.direction, + capital_usdt=body.capital_usdt, + risk_percent=body.risk_percent, + leverage=body.leverage, + entry_price=body.entry_price, + stop_loss=body.stop_loss, + add_upper=body.add_upper, + take_profit=body.take_profit, + dca_legs=body.dca_legs, + exchange_id=body.exchange_id, + base=body.base, + ) + if err: + return JSONResponse({"ok": False, "msg": err}, status_code=400) + return {"ok": True, "data": data} + + +@app.post("/api/calculator/roll") +def api_calculator_roll(body: RollCalculatorBody): + from lib.hub.hub_calculator_lib import calc_roll_calculator + + data, err = calc_roll_calculator( + direction=body.direction, + capital_usdt=body.capital_usdt, + risk_percent=body.risk_percent, + entry_price=body.entry_price, + stop_loss=body.stop_loss, + take_profit=body.take_profit, + add_legs=[leg.model_dump() for leg in body.add_legs], + legs_done=body.legs_done, + exchange_id=body.exchange_id, + base=body.base, + ) + if err: + return JSONResponse({"ok": False, "msg": err}, status_code=400) + return {"ok": True, "data": data} + + +def _find_exchange_by_key(exchange_key: str) -> dict | None: + key = (exchange_key or "").strip().lower() + if not key: + return None + for ex in load_settings().get("exchanges") or []: + if str(ex.get("key") or "").strip().lower() == key: + return ex + if str(ex.get("id") or "").strip() == exchange_key.strip(): + return ex + return None + + +def _fetch_instance_trades_archive_sync( + ex: dict, + *, + days: int = 365, + limit: int = 2000, +) -> dict: + base = (ex.get("flask_url") or "").rstrip("/") + if not base: + return {"ok": False, "msg": "未配置 flask_url"} + params = {"days": str(int(days)), "limit": str(int(limit))} + url = f"{base}/api/hub/trades/archive?{urlencode(params)}" + try: + with httpx.Client(timeout=HUB_FLASK_TIMEOUT) as client: + r = client.get(url, headers=_hub_headers()) + if r.status_code >= 400: + parsed = _parse_http_json_body(r) + parsed.setdefault("ok", False) + parsed.setdefault("status", r.status_code) + return parsed + data = r.json() if r.content else {} + if isinstance(data, dict): + data.setdefault("ok", True) + return data + return {"ok": False, "msg": "无效 JSON"} + except Exception as e: + return {"ok": False, "msg": str(e)} + + +def _fetch_instance_ohlcv_sync( + ex: dict, + *, + symbol: str, + timeframe: str, + since_ms: int | None, + limit: int, +) -> dict: + base = (ex.get("flask_url") or "").rstrip("/") + if not base: + return {"ok": False, "msg": "未配置 flask_url"} + params = {"symbol": symbol, "timeframe": timeframe, "limit": str(int(limit))} + if since_ms is not None and int(since_ms) > 0: + params["since_ms"] = str(int(since_ms)) + url = f"{base}/api/hub/ohlcv?{urlencode(params)}" + try: + with httpx.Client(timeout=HUB_FLASK_TIMEOUT) as client: + r = client.get(url, headers=_hub_headers()) + if r.status_code >= 400: + parsed = _parse_http_json_body(r) + parsed.setdefault("ok", False) + return parsed + data = r.json() if r.content else {} + return data if isinstance(data, dict) else {"ok": False, "msg": "无效 JSON"} + except Exception as e: + return {"ok": False, "msg": str(e)} + + +@app.get("/api/chart/meta") +def api_chart_meta(): + tfs = [tf for tf in CHART_TIMEFRAME_ORDER if tf in CHART_TIMEFRAMES] + exchanges = [] + for ex in enabled_exchanges(load_settings()): + exchanges.append( + { + "id": ex.get("id"), + "key": ex.get("key"), + "name": ex.get("name"), + } + ) + return { + "ok": True, + "timeframes": [tf for tf in tfs if tf in CHART_TIMEFRAMES], + "retention_days": retention_days(), + "retention_policy": retention_policy_meta(), + "limits": {tf: bar_limit_for_timeframe(tf) for tf in tfs if tf in CHART_TIMEFRAMES}, + "initial_limits": {tf: chart_initial_limit(tf) for tf in tfs if tf in CHART_TIMEFRAMES}, + "chunk_limits": {tf: chart_chunk_limit(tf) for tf in tfs if tf in CHART_TIMEFRAMES}, + "memory_caps": {tf: chart_memory_cap(tf) for tf in tfs if tf in CHART_TIMEFRAMES}, + "exchanges": exchanges, + "volume_rank_top_n": TOP_N_DEFAULT, + "volume_rank_reset_hour": volume_rank_reset_hour(), + } + + +@app.get("/api/chart/volume-rank") +def api_chart_volume_rank(exchange_key: str = "", refresh: str = ""): + force = (refresh or "").strip().lower() in ("1", "true", "yes", "on") + if force: + result = _refresh_volume_ranks(force=True) + if not result.get("ok"): + raise HTTPException(status_code=502, detail=result.get("msg") or "刷新失败") + cache = _get_volume_rank_cache() + ex_k = (exchange_key or "").strip().lower() + targets = enabled_exchanges(load_settings()) + required_keys = [ + str(ex.get("key") or "").strip().lower() + for ex in targets + if ex.get("enabled") and str(ex.get("key") or "").strip() + ] + need_keys = [ex_k] if ex_k else required_keys + if cache_needs_refresh(cache, required_keys=need_keys): + _refresh_volume_ranks(force=True) + cache = _get_volume_rank_cache() + elif ex_k: + row = (cache.get("exchanges") or {}).get(ex_k) or {} + if _exchange_rank_row_stale(row): + _refresh_volume_ranks(force=True) + cache = _get_volume_rank_cache() + if ex_k: + ex = _find_exchange_by_key(ex_k) + if not ex: + raise HTTPException(status_code=400, detail="交易所不存在") + payload = get_cached_rank(cache, ex_k, top_n=TOP_N_DEFAULT) + payload["items"] = [ + {**row, "volume_label": format_volume_quote(row.get("volume_quote"))} + for row in payload.get("items") or [] + ] + payload["reset_hour"] = volume_rank_reset_hour() + err = ((cache.get("exchanges") or {}).get(ex_k) or {}).get("error") + if err and not payload.get("items"): + payload["ok"] = False + payload["msg"] = err + return payload + exchanges_out = {} + for ex in enabled_exchanges(load_settings()): + key = str(ex.get("key") or "").strip().lower() + if not key: + continue + row = get_cached_rank(cache, key, top_n=TOP_N_DEFAULT) + row["name"] = ex.get("name") + row["items"] = [ + {**item, "volume_label": format_volume_quote(item.get("volume_quote"))} + for item in row.get("items") or [] + ] + exchanges_out[key] = row + return { + "ok": True, + "rank_date": cache.get("rank_date"), + "updated_at": cache.get("updated_at"), + "reset_hour": volume_rank_reset_hour(), + "exchanges": exchanges_out, + } + + +@app.post("/api/chart/volume-rank/refresh") +async def api_chart_volume_rank_refresh(): + result = await asyncio.to_thread(_refresh_volume_ranks, force=True) + if not result.get("ok"): + raise HTTPException(status_code=502, detail=result.get("msg") or "刷新失败") + return result + + +@app.get("/api/chart/ohlcv") +def api_chart_ohlcv( + exchange_key: str = "", + symbol: str = "", + timeframe: str = "1d", + refresh: str = "", + tail: str = "", + limit: int = 0, + before_ms: str = "", +): + ex = _find_exchange_by_key(exchange_key) + if not ex: + raise HTTPException(status_code=400, detail="交易所不存在") + if not ex.get("enabled"): + raise HTTPException(status_code=400, detail="该交易所未启用") + sym = (symbol or "").strip().upper() + if not sym: + raise HTTPException(status_code=400, detail="请输入币种") + ex_key = str(ex.get("key") or "").strip().lower() + force = (refresh or "").strip().lower() in ("1", "true", "yes", "on") + tail_refresh = (tail or "").strip().lower() in ("1", "true", "yes", "on") + lim = int(limit) if int(limit or 0) > 0 else None + bms_raw = (before_ms or "").strip() + bms = None + if bms_raw: + try: + bms = int(bms_raw) + except ValueError: + raise HTTPException(status_code=400, detail="before_ms 无效") + clear_db = force and not tail_refresh and bms is None + + def remote_fetch(**kwargs): + tf_use = kwargs.get("timeframe") or timeframe + return _fetch_instance_ohlcv_sync( + ex, + symbol=kwargs.get("symbol") or sym, + timeframe=tf_use, + since_ms=kwargs.get("since_ms"), + limit=int(kwargs.get("limit") or bar_limit_for_timeframe(tf_use)), + ) + + result = resolve_chart_bars( + ex_key, + sym, + timeframe, + remote_fetch, + force_refresh=force, + tail_refresh=tail_refresh, + clear_db=clear_db, + limit=lim, + before_ms=bms, + ) + if not result.get("ok"): + raise HTTPException(status_code=502, detail=result.get("msg") or "K线加载失败") + if not result.get("candles") and result.get("before_ms") is None: + raise HTTPException(status_code=502, detail=result.get("msg") or "无 K 线") + tick = result.get("price_tick") + last = result["candles"][-1] if result.get("candles") else None + result["ohlcv"] = format_ohlcv_detail( + { + "open": last.get("open") if last else None, + "high": last.get("high") if last else None, + "low": last.get("low") if last else None, + "close": last.get("close") if last else None, + "volume": last.get("volume") if last else None, + } + if last + else None, + tick, + ) + result["chart_version"] = chart_poll_store.version + result["series_version"] = chart_poll_store.series_version(ex_key, sym, timeframe) + result["chart_poll_interval_sec"] = HUB_CHART_POLL_INTERVAL + return result + + +class ChartWatchBody(BaseModel): + exchange_key: str = "" + symbol: str = "" + timeframe: str = "5m" + + +@app.post("/api/chart/watch") +async def api_chart_watch(body: ChartWatchBody = Body(...)): + ex_k = (body.exchange_key or "").strip().lower() + sym = (body.symbol or "").strip().upper() + tf = (body.timeframe or "5m").strip() + if not ex_k or not sym: + raise HTTPException(status_code=400, detail="缺少 exchange_key 或 symbol") + if tf not in CHART_TIMEFRAMES: + raise HTTPException(status_code=400, detail="不支持的周期") + key = chart_poll_store.touch_watch(ex_k, sym, tf) + chart_poll_store.request_refresh() + return { + "ok": True, + "series_key": key, + "series_version": chart_poll_store.series_version(ex_k, sym, tf), + "chart_version": chart_poll_store.version, + "watch_ttl_sec": HUB_CHART_WATCH_TTL_SEC, + } + + +@app.post("/api/chart/unwatch") +async def api_chart_unwatch(body: ChartWatchBody = Body(...)): + chart_poll_store.clear_watch(body.exchange_key, body.symbol, body.timeframe) + return {"ok": True} + + +@app.get("/api/chart/stream") +async def api_chart_stream(): + from fastapi.responses import StreamingResponse + + return StreamingResponse( + chart_poll_store.iter_sse(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + +@app.get("/api/chart/poll/meta") +async def api_chart_poll_meta(): + return chart_poll_store.event_dict() + + +@app.get("/api/settings/meta") +def api_settings_meta(): + po = public_origin() + return { + "env_disabled_ids": sorted(env_force_disabled_ids()), + "hub_bridge_token_set": bool(HUB_BRIDGE_TOKEN), + "capability_options": ["key", "trend"], + "public_origin": f"{po[0]}://{po[1]}" if po else None, + "public_origin_hint": ( + "未设置 HUB_PUBLIC_ORIGIN 时,复盘链接若为 127.0.0.1,仅服务器本机浏览器可打开" + if not po + else "复盘/展示链接已替换为对外地址" + ), + "password_required": password_required(), + } + + +async def _fetch_agent_status(client: httpx.AsyncClient, ex: dict) -> dict: + url = f"{ex['agent_url'].rstrip('/')}/status" + try: + r = await client.get(url, headers=_agent_headers(), timeout=HUB_AGENT_TIMEOUT) + body = r.json() if r.content else {} + return { + "id": ex["id"], + "name": ex["name"], + "key": ex.get("key"), + "agent_url": ex["agent_url"], + "flask_url": ex.get("flask_url"), + "capabilities": ex.get("capabilities") or [], + "http_ok": r.status_code == 200, + "agent": body, + "error": body.get("error") if isinstance(body, dict) else None, + } + except Exception as e: + return { + "id": ex["id"], + "name": ex["name"], + "key": ex.get("key"), + "agent_url": ex["agent_url"], + "flask_url": ex.get("flask_url"), + "capabilities": ex.get("capabilities") or [], + "http_ok": False, + "error": str(e), + "agent": None, + } + + +def _parse_http_json_body(r: httpx.Response) -> dict: + text = (r.text or "").strip() + if not text: + return {"ok": False, "status": r.status_code, "text": "(empty body)"} + try: + data = r.json() + if isinstance(data, dict): + return data + return {"ok": False, "status": r.status_code, "text": text[:500]} + except Exception: + snippet = text[:500] + if snippet.lstrip().lower().startswith(" dict | None: + base = (ex.get("flask_url") or "").rstrip("/") + if not base: + return None + try: + if method == "GET": + r = await client.get(f"{base}{path}", headers=_hub_headers(), timeout=HUB_FLASK_TIMEOUT) + else: + headers = {**_hub_headers(), "Content-Type": "application/json"} + if json_body is not None: + r = await client.post( + f"{base}{path}", headers=headers, json=json_body, timeout=120.0 + ) + else: + r = await client.post( + f"{base}{path}", headers=headers, data=data, timeout=120.0 + ) + if r.status_code >= 400: + parsed = _parse_http_json_body(r) + parsed.setdefault("ok", False) + parsed.setdefault("status", r.status_code) + return parsed + return _parse_http_json_body(r) + except Exception as e: + return {"ok": False, "error": str(e)} + + +async def _notify_instance_user_close( + client: httpx.AsyncClient, ex: dict, *, count: int = 1 +) -> dict | None: + """登记实例侧用户主动平仓风控(中控点平仓/全平)。""" + if count <= 0 or not (ex.get("flask_url") or "").strip(): + return None + return await _fetch_flask_json( + client, + ex, + "/api/hub/account-risk/user-close", + method="POST", + json_body={"source": "user_hub", "count": int(count)}, + ) + + +def _flask_error_from_hub_mon(hub_mon: dict | None) -> str | None: + if not isinstance(hub_mon, dict) or hub_mon.get("ok") is not False: + return None + st = hub_mon.get("status") + if st == 404: + return ( + "HTTP 404:该 Flask 未注册 /api/hub/*(hub_bridge 未加载)。" + "请在仓库根目录 git pull 后 pm2 restart crypto_binance crypto_gate crypto_gate_bot," + "并查看启动日志是否含 [hub_bridge] ImportError" + ) + return ( + hub_mon.get("msg") + or hub_mon.get("error") + or (f"HTTP {st}" if st else None) + or (str(hub_mon.get("text") or "")[:120] or None) + ) + + +def _cond_order_trigger_key(price: object) -> str | None: + if price is None or price == "": + return None + try: + return f"{float(price):.12g}" + except (TypeError, ValueError): + return None + + +def _merge_conditional_orders_no_dup( + existing: list, extra: list +) -> list: + """子代理已拉到的条件单与 Flask exchange_tpsl 合成行按触发价/订单号去重,避免 Gate 显示 4 笔实为 2 笔。""" + if not extra: + return list(existing) if existing else [] + if not existing: + return list(extra) + triggers: set[str] = set() + order_ids: set[str] = set() + out: list = [] + for row in existing: + if not isinstance(row, dict): + continue + out.append(row) + k = _cond_order_trigger_key(row.get("trigger_price")) + if k: + triggers.add(k) + oid = row.get("id") + if oid not in (None, ""): + order_ids.add(str(oid)) + for row in extra: + if not isinstance(row, dict): + continue + k = _cond_order_trigger_key(row.get("trigger_price")) + oid = row.get("id") + if k and k in triggers: + continue + if oid not in (None, "") and str(oid) in order_ids: + continue + out.append(row) + if k: + triggers.add(k) + if oid not in (None, ""): + order_ids.add(str(oid)) + return out + + +def _tpsl_slots_to_conditional_orders(exchange_tpsl: dict, symbol: str) -> list[dict]: + """将实例 price_snapshot 的 exchange_tpsl 转为中控条件单结构。""" + out: list[dict] = [] + if not isinstance(exchange_tpsl, dict): + return out + for role, label in (("sl", "止损"), ("tp", "止盈")): + slot = exchange_tpsl.get(role) + if not isinstance(slot, dict): + continue + trig = slot.get("trigger_price") + if trig is None: + continue + try: + trig_f = float(trig) + except (TypeError, ValueError): + continue + oid = slot.get("order_id") + out.append( + { + "id": str(oid) if oid is not None else "", + "symbol": symbol, + "channel": "algo", + "category": "conditional", + "label": f"{label} {trig_f:g}", + "trigger_price": trig_f, + "amount": slot.get("amount"), + "status": "open", + } + ) + return out + + +def _exchange_tpsl_from_hub_order(hub_orders: list, symbol: str, side: str) -> dict | None: + """趋势保本移交后:用下单监控计划价补全 exchange_tpsl(与实例页一致)。""" + side_l = (side or "").lower() + for o in hub_orders: + if not isinstance(o, dict): + continue + o_sym = o.get("exchange_symbol") or o.get("symbol") or "" + if not _symbols_match(symbol, o_sym): + continue + if (o.get("direction") or "").lower() != side_l: + continue + sl = o.get("stop_loss") + tp = o.get("take_profit") + if sl in (None, "") and tp in (None, ""): + continue + slots: dict = {"sl": None, "tp": None} + if sl not in (None, ""): + try: + slots["sl"] = {"trigger_price": float(sl), "order_id": None} + except (TypeError, ValueError): + pass + if tp not in (None, ""): + try: + slots["tp"] = {"trigger_price": float(tp), "order_id": None} + except (TypeError, ValueError): + pass + if slots["sl"] or slots["tp"]: + return slots + return None + + +def _find_exchange_tpsl_for_position( + symbol: str, + side: str, + order_prices: list, + hub_orders: list, +) -> dict | None: + side_l = (side or "").lower() + op_by_id = { + op.get("id"): op + for op in order_prices + if isinstance(op, dict) and op.get("id") is not None + } + for o in hub_orders: + if not isinstance(o, dict): + continue + o_sym = o.get("exchange_symbol") or o.get("symbol") or "" + if not _symbols_match(symbol, o_sym): + continue + if (o.get("direction") or "").lower() != side_l: + continue + op = op_by_id.get(o.get("id")) + if not isinstance(op, dict): + continue + et = op.get("exchange_tpsl") + if isinstance(et, dict) and (et.get("sl") or et.get("tp")): + return et + for op in order_prices: + if not isinstance(op, dict): + continue + if not _symbols_match(symbol, op.get("symbol") or ""): + continue + et = op.get("exchange_tpsl") + if isinstance(et, dict) and (et.get("sl") or et.get("tp")): + return et + return None + + +def _merge_flask_order_price_fields(hub_mon: dict | None, snap: dict | None) -> None: + """将 price_snapshot 中的快照盈亏比、已保本状态合并进 hub_monitor.orders。""" + if not isinstance(hub_mon, dict) or not isinstance(snap, dict): + return + order_prices = snap.get("order_prices") or [] + op_by_id = { + op.get("id"): op + for op in order_prices + if isinstance(op, dict) and op.get("id") is not None + } + orders = hub_mon.get("orders") or [] + if not isinstance(orders, list): + return + for o in orders: + if not isinstance(o, dict): + continue + op = op_by_id.get(o.get("id")) + if not isinstance(op, dict): + continue + if op.get("rr_ratio") is not None: + o["rr_ratio"] = op["rr_ratio"] + if "sl_breakeven_secured" in op: + o["sl_breakeven_secured"] = bool(op["sl_breakeven_secured"]) + for key in ( + "stop_loss", + "take_profit", + "stop_loss_display", + "take_profit_display", + "display_rr_ratio", + "exchange_initial_margin", + "plan_margin", + "time_close_enabled", + "time_close_hours", + "time_close_at_ms", + "time_close_label", + "time_close_countdown", + "time_close_remaining_sec", + ): + if key in op and op[key] not in (None, ""): + o[key] = op[key] + + +def _merge_flask_position_breakeven(agent_row: dict, snap: dict | None, hub_mon: dict | None) -> None: + """将 price_snapshot 的已保本状态同步到 agent 持仓,供中控首页表格展示。""" + ag = agent_row.get("agent") + if not isinstance(ag, dict) or not isinstance(snap, dict): + return + positions = ag.get("positions") + if not isinstance(positions, list) or not positions: + return + order_prices = snap.get("order_prices") or [] + hub_orders = [] + if isinstance(hub_mon, dict): + hub_orders = hub_mon.get("orders") or [] + op_by_id = { + op.get("id"): op + for op in order_prices + if isinstance(op, dict) and op.get("id") is not None + } + for p in positions: + if not isinstance(p, dict): + continue + sym = p.get("symbol") or "" + side = (p.get("side") or "").lower() + matched = None + for o in hub_orders: + if not isinstance(o, dict): + continue + o_sym = o.get("exchange_symbol") or o.get("symbol") or "" + if not _symbols_match(sym, o_sym): + continue + if (o.get("direction") or "").lower() != side: + continue + matched = op_by_id.get(o.get("id")) + break + if matched is None: + for op in order_prices: + if not isinstance(op, dict): + continue + if not _symbols_match(sym, op.get("symbol") or ""): + continue + matched = op + break + if isinstance(matched, dict) and "sl_breakeven_secured" in matched: + p["sl_breakeven_secured"] = bool(matched["sl_breakeven_secured"]) + + +def _agent_position_has_mark(p: dict) -> bool: + try: + v = float(p.get("mark_price")) + return v > 0 + except (TypeError, ValueError): + return False + + +def _apply_agent_mark_price(p: dict, mark_price: object, mark_display: object = None) -> None: + try: + mpf = float(mark_price) + except (TypeError, ValueError): + return + if mpf <= 0: + return + p["mark_price"] = mpf + disp = mark_display + if disp is not None and str(disp).strip() not in ("", "-"): + p["mark_price_fmt"] = str(disp) + + +def _find_matched_order_price_op( + p: dict, + order_prices: list, + hub_orders: list, + op_by_id: dict, +) -> dict | None: + sym = p.get("symbol") or "" + side = (p.get("side") or "").lower() + for o in hub_orders: + if not isinstance(o, dict): + continue + o_sym = o.get("exchange_symbol") or o.get("symbol") or "" + if not _symbols_match(sym, o_sym): + continue + if (o.get("direction") or "").lower() != side: + continue + matched = op_by_id.get(o.get("id")) + if isinstance(matched, dict): + return matched + break + for op in order_prices: + if not isinstance(op, dict): + continue + if not _symbols_match(sym, op.get("symbol") or ""): + continue + return op + return None + + +def _merge_flask_position_mark_price( + agent_row: dict, snap: dict | None, hub_mon: dict | None +) -> None: + """子代理无标记价时,用实例 price_snapshot 的交易所标记价补全中控持仓展示。""" + ag = agent_row.get("agent") + if not isinstance(ag, dict) or not isinstance(snap, dict): + return + positions = ag.get("positions") + if not isinstance(positions, list) or not positions: + return + order_prices = snap.get("order_prices") or [] + hub_orders = [] + if isinstance(hub_mon, dict): + hub_orders = hub_mon.get("orders") or [] + op_by_id = { + op.get("id"): op + for op in order_prices + if isinstance(op, dict) and op.get("id") is not None + } + for p in positions: + if not isinstance(p, dict) or _agent_position_has_mark(p): + continue + matched = _find_matched_order_price_op(p, order_prices, hub_orders, op_by_id) + if isinstance(matched, dict): + _apply_agent_mark_price( + p, + matched.get("exchange_mark_price"), + matched.get("exchange_mark_price_display"), + ) + position_marks = snap.get("position_marks") or [] + if not isinstance(position_marks, list): + return + for p in positions: + if not isinstance(p, dict) or _agent_position_has_mark(p): + continue + sym = p.get("symbol") or "" + side = (p.get("side") or "").lower() + for pm in position_marks: + if not isinstance(pm, dict): + continue + if not _symbols_match(sym, pm.get("symbol") or ""): + continue + if (pm.get("side") or "").lower() != side: + continue + _apply_agent_mark_price( + p, pm.get("mark_price"), pm.get("mark_price_display") + ) + break + + +def _merge_flask_exchange_tpsl(agent_row: dict, snap: dict | None, hub_mon: dict | None) -> None: + """子代理挂单为空时,用实例 Flask 已算好的 exchange_tpsl 补全展示。""" + ag = agent_row.get("agent") + if not isinstance(ag, dict): + return + positions = ag.get("positions") + if not isinstance(positions, list) or not positions: + return + if not isinstance(snap, dict): + return + order_prices = snap.get("order_prices") or [] + hub_orders = [] + if isinstance(hub_mon, dict): + hub_orders = hub_mon.get("orders") or [] + for p in positions: + if not isinstance(p, dict): + continue + sym = p.get("symbol") or "" + side = p.get("side") or "" + et = _find_exchange_tpsl_for_position(sym, side, order_prices, hub_orders) + if not et: + et = _exchange_tpsl_from_hub_order(hub_orders, sym, side) + if not et: + continue + p["exchange_tpsl"] = et + cond = p.get("conditional_orders") or [] + merged = _tpsl_slots_to_conditional_orders(et, sym) + p["conditional_orders"] = _merge_conditional_orders_no_dup(cond, merged) + + +async def _fetch_exchange_flask_bundle( + client: httpx.AsyncClient, ex: dict +) -> tuple[dict | None, dict | None, list | None, dict | None, dict | None]: + """单所 Flask:monitor / meta / price_snapshot / account(有 flask_url 时)并行拉取。""" + caps = ex.get("capabilities") or [] + tasks = [ + _fetch_flask_json(client, ex, "/api/hub/monitor"), + _fetch_flask_json(client, ex, "/api/hub/meta"), + ] + has_flask = bool((ex.get("flask_url") or "").strip()) + if has_flask: + tasks.extend( + [ + _fetch_flask_json(client, ex, "/api/price_snapshot"), + _fetch_flask_json(client, ex, "/api/hub/account"), + ] + ) + results = await asyncio.gather(*tasks) + hub_mon = results[0] + meta = results[1] + snap = results[2] if has_flask and len(results) > 2 else None + account = results[3] if has_flask and len(results) > 3 else None + key_prices = None + want_prices = HUB_BOARD_KEY_PRICES and "key" in caps + if want_prices and isinstance(snap, dict): + key_prices = snap.get("key_prices") + return ( + hub_mon, + meta, + key_prices, + snap if isinstance(snap, dict) else None, + account if isinstance(account, dict) else None, + ) + + +async def _assemble_board_row( + client: httpx.AsyncClient, ex: dict, agent_row: dict +) -> dict: + hub_mon, meta, key_prices, snap, account = await _fetch_exchange_flask_bundle( + client, ex + ) + if isinstance(hub_mon, dict): + _merge_flask_order_price_fields(hub_mon, snap) + _merge_flask_exchange_tpsl(agent_row, snap, hub_mon if isinstance(hub_mon, dict) else None) + _merge_flask_position_breakeven(agent_row, snap, hub_mon if isinstance(hub_mon, dict) else None) + _merge_flask_position_mark_price(agent_row, snap, hub_mon if isinstance(hub_mon, dict) else None) + flask_ok = isinstance(hub_mon, dict) and hub_mon.get("ok") is not False + acct_ok = isinstance(account, dict) and account.get("ok") is not False + raw_review = (ex.get("review_url") or "").strip() + review_link = browser_url(raw_review) if raw_review else default_review_url( + ex.get("flask_url") + ) + return { + **agent_row, + "flask_url": ex.get("flask_url") or "", + "flask_url_browser": browser_url(ex.get("flask_url")), + "review_url": review_link, + "hub_monitor": hub_mon, + "flask_ok": flask_ok, + "flask_error": _flask_error_from_hub_mon(hub_mon if isinstance(hub_mon, dict) else None), + "meta": (meta or {}).get("meta") if isinstance(meta, dict) else meta, + "key_prices": key_prices, + "funding_usdt": account.get("funding_usdt") if acct_ok else None, + "trading_usdt": account.get("trading_usdt") if acct_ok else None, + "available_trading_usdt": account.get("available_trading_usdt") if acct_ok else None, + "account_ok": acct_ok, + } + + +async def _build_monitor_board_payload() -> dict: + exchanges = enabled_exchanges() + async with httpx.AsyncClient() as client: + agent_rows = await asyncio.gather( + *[_fetch_agent_status(client, ex) for ex in exchanges] + ) + out = await asyncio.gather( + *[ + _assemble_board_row(client, ex, agent_row) + for ex, agent_row in zip(exchanges, agent_rows) + ] + ) + return { + "rows": list(out), + "updated_at": __import__("datetime").datetime.now().isoformat(timespec="seconds"), + } + + +@app.get("/api/monitor/board") +@app.get("/api/monitor/board/snapshot") +async def api_monitor_board_snapshot(): + """读后台缓存快照;完整聚合由 hub 每 HUB_BOARD_POLL_INTERVAL 秒执行。""" + return board_store.snapshot_dict() + + +@app.get("/api/monitor/board/stream") +async def api_monitor_board_stream(): + from fastapi.responses import StreamingResponse + + return StreamingResponse( + board_store.iter_sse(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + +@app.post("/api/monitor/board/refresh") +async def api_monitor_board_refresh(): + _schedule_board_refresh() + return {"ok": True, "board_version": board_store.version} + + +@app.get("/api/host/status") +async def api_host_status(): + from lib.hub.hub_host_status_lib import get_host_status + + return await asyncio.to_thread(get_host_status) + + +def _require_hub_logged_in(request: Request) -> None: + if password_required() and not validate_session_token(request.cookies.get(SESSION_COOKIE)): + raise HTTPException(status_code=401, detail="未登录中控") + + +@app.get("/api/instance/open-url") +def api_instance_open_url( + request: Request, + exchange_id: str, + next: str = "/", + embed: str = "", + hub_theme: str = "", +): + """已登录中控时生成实例 SSO 打开链接(2h 有效、单次使用,复用 HUB_BRIDGE_TOKEN)。""" + _require_hub_logged_in(request) + if not HUB_BRIDGE_TOKEN: + raise HTTPException(status_code=503, detail="未配置 HUB_BRIDGE_TOKEN,无法签发实例打开链接") + ex = _find_exchange(exchange_id) + if not ex: + raise HTTPException(status_code=404, detail="未知交易所 id") + base = browser_url((ex.get("flask_url") or "").strip()).rstrip("/") + if not base: + raise HTTPException(status_code=400, detail="该账户未配置 flask_url") + ex_key = (ex.get("key") or "").strip().lower() + if not ex_key: + raise HTTPException(status_code=400, detail="该账户缺少 key(用于 SSO 校验)") + nxt = safe_next_path(next) + token = mint_hub_sso_token(ex_key, nxt) + if not token: + raise HTTPException(status_code=503, detail="签发 SSO 失败") + params = {"token": token, "next": nxt} + if (embed or "").strip().lower() in ("1", "true", "yes", "on"): + params["embed"] = "1" + ht = (hub_theme or "").strip().lower() + if ht in ("light", "dark"): + params["hub_theme"] = ht + q = urlencode(params) + return { + "ok": True, + "url": f"{base}/hub-sso?{q}", + "expires_in": HUB_SSO_TTL_SEC, + "exchange_id": exchange_id, + "exchange_key": ex_key, + } + + +class CloseAllBody(BaseModel): + exclude_ids: list[str] = Field(default_factory=list) + + +class ClosePositionBody(BaseModel): + symbol: str + side: str + + +class CancelOrderBody(BaseModel): + symbol: str + order_id: str + channel: str = "regular" + + +class CancelSymbolOrdersBody(BaseModel): + symbol: str + scope: str = "all" + + +class PlaceTpslBody(BaseModel): + symbol: str + side: str + stop_loss: float + take_profit: float + contracts: float | None = None + + +class TrendPlanActionBody(BaseModel): + plan_id: int + breakeven_offset_pct: float | None = None + + +def _flask_hub_messages(parsed: dict | None) -> tuple[bool, str]: + if not isinstance(parsed, dict): + return False, "实例返回无效" + msgs = list(parsed.get("messages") or []) + if parsed.get("msg"): + msgs.insert(0, str(parsed["msg"])) + if parsed.get("error"): + msgs.append(str(parsed["error"])) + ok = parsed.get("ok") is not False + if parsed.get("ok") is True: + ok = True + elif parsed.get("ok") is False: + ok = False + else: + for m in msgs: + if any( + k in str(m) + for k in ("失败", "错误", "无法", "缺少", "过期", "未找到", "不允许", "异常") + ): + ok = False + break + text = ";".join(str(x) for x in msgs if x) or ("成功" if ok else "操作失败") + return ok, text + + +@app.post("/api/trend/{exchange_id}/stop") +async def api_trend_plan_stop(exchange_id: str, body: TrendPlanActionBody): + ex = _find_exchange(exchange_id) + if not ex or not ex.get("enabled"): + raise HTTPException(status_code=404, detail="账户未启用") + if "trend" not in (ex.get("capabilities") or []): + raise HTTPException(status_code=400, detail="该账户未启用趋势计划监控") + pid = int(body.plan_id) + async with httpx.AsyncClient() as client: + parsed = await _fetch_flask_json( + client, ex, f"/api/hub/trend/stop/{pid}", method="POST" + ) + ok, text = _flask_hub_messages(parsed) + _schedule_board_refresh() + return {"ok": ok, "message": text, "payload": parsed} + + +@app.post("/api/trend/{exchange_id}/breakeven") +async def api_trend_plan_breakeven(exchange_id: str, body: TrendPlanActionBody): + ex = _find_exchange(exchange_id) + if not ex or not ex.get("enabled"): + raise HTTPException(status_code=404, detail="账户未启用") + if "trend" not in (ex.get("capabilities") or []): + raise HTTPException(status_code=400, detail="该账户未启用趋势计划监控") + pid = int(body.plan_id) + data = {} + if body.breakeven_offset_pct is not None: + data["breakeven_offset_pct"] = str(body.breakeven_offset_pct) + async with httpx.AsyncClient() as client: + parsed = await _fetch_flask_json( + client, + ex, + f"/api/hub/trend/breakeven/{pid}", + method="POST", + data=data, + ) + ok, text = _flask_hub_messages(parsed) + _schedule_board_refresh() + return {"ok": ok, "message": text, "payload": parsed} + + +@app.post("/api/orders/{exchange_id}/cancel") +async def api_cancel_order(exchange_id: str, body: CancelOrderBody): + ex = _find_exchange(exchange_id) + if not ex or not ex.get("enabled"): + raise HTTPException(status_code=404, detail="账户未启用") + url = f"{ex['agent_url'].rstrip('/')}/orders/cancel" + async with httpx.AsyncClient() as client: + r = await client.post( + url, + headers=_agent_headers(), + json={ + "symbol": body.symbol, + "order_id": body.order_id, + "channel": body.channel or "regular", + }, + timeout=60.0, + ) + try: + payload = r.json() + except Exception: + payload = {"raw": (r.text or "")[:2000]} + out = { + "exchange": ex, + "status_code": r.status_code, + "payload": payload, + "ok": bool(isinstance(payload, dict) and payload.get("ok")), + } + _schedule_board_refresh() + return out + + +@app.post("/api/orders/{exchange_id}/cancel-symbol") +async def api_cancel_symbol_orders(exchange_id: str, body: CancelSymbolOrdersBody): + ex = _find_exchange(exchange_id) + if not ex or not ex.get("enabled"): + raise HTTPException(status_code=404, detail="账户未启用") + url = f"{ex['agent_url'].rstrip('/')}/orders/cancel-symbol" + async with httpx.AsyncClient() as client: + r = await client.post( + url, + headers=_agent_headers(), + json={"symbol": body.symbol, "scope": body.scope or "all"}, + timeout=120.0, + ) + try: + payload = r.json() + except Exception: + payload = {"raw": (r.text or "")[:2000]} + out = { + "exchange": ex, + "status_code": r.status_code, + "payload": payload, + "ok": bool(isinstance(payload, dict) and payload.get("ok")), + } + _schedule_board_refresh() + return out + + +@app.post("/api/close/{exchange_id}/position") +async def api_close_position(exchange_id: str, body: ClosePositionBody): + ex = _find_exchange(exchange_id) + if not ex or not ex.get("enabled"): + raise HTTPException(status_code=404, detail="账户未启用") + sym = (body.symbol or "").strip() + side = (body.side or "").strip().lower() + if not sym: + raise HTTPException(status_code=400, detail="symbol 不能为空") + if side not in ("long", "short"): + raise HTTPException(status_code=400, detail="side 须为 long 或 short") + url = f"{ex['agent_url'].rstrip('/')}/emergency/close-position" + async with httpx.AsyncClient() as client: + r = await client.post( + url, + headers=_agent_headers(), + json={"symbol": sym, "side": side}, + timeout=120.0, + ) + try: + payload = r.json() + except Exception: + payload = {"raw": (r.text or "")[:2000]} + out = { + "exchange": ex, + "status_code": r.status_code, + "payload": payload, + "ok": bool(isinstance(payload, dict) and payload.get("ok")), + } + if out.get("ok"): + ex_key = (ex.get("key") or "").strip().lower() + async with httpx.AsyncClient() as flask_client: + if ex_key in ("gate", "gate_bot"): + order_sync = await _fetch_flask_json( + flask_client, + ex, + "/api/hub/order/sync-flat", + method="POST", + json_body={"symbol": sym, "side": side}, + ) + if isinstance(order_sync, dict): + out["order_sync"] = order_sync + if "trend" in (ex.get("capabilities") or []): + sync_parsed = await _fetch_flask_json( + flask_client, + ex, + "/api/hub/trend/sync-flat", + method="POST", + json_body={"symbol": sym, "side": side}, + ) + if isinstance(sync_parsed, dict): + out["trend_sync"] = sync_parsed + roll_sync = await _fetch_flask_json( + flask_client, + ex, + "/api/hub/roll/sync-flat", + method="POST", + json_body={"symbol": sym, "side": side}, + ) + if isinstance(roll_sync, dict): + out["roll_sync"] = roll_sync + risk_sync = await _notify_instance_user_close(flask_client, ex, count=1) + if isinstance(risk_sync, dict): + out["risk_sync"] = risk_sync + _schedule_board_refresh() + return out + + +@app.post("/api/orders/{exchange_id}/place-tpsl") +async def api_place_tpsl(exchange_id: str, body: PlaceTpslBody): + ex = _find_exchange(exchange_id) + if not ex or not ex.get("enabled"): + raise HTTPException(status_code=404, detail="账户未启用") + url = f"{ex['agent_url'].rstrip('/')}/orders/place-tpsl" + async with httpx.AsyncClient() as client: + r = await client.post( + url, + headers=_agent_headers(), + json={ + "symbol": body.symbol, + "side": body.side, + "stop_loss": body.stop_loss, + "take_profit": body.take_profit, + "contracts": body.contracts, + }, + timeout=120.0, + ) + try: + payload = r.json() + except Exception: + payload = {"raw": (r.text or "")[:2000]} + out = { + "exchange": ex, + "status_code": r.status_code, + "payload": payload, + "ok": bool(isinstance(payload, dict) and payload.get("ok")), + } + _schedule_board_refresh() + return out + + +@app.post("/api/close/{exchange_id}") +async def api_close_exchange(exchange_id: str): + ex = _find_exchange(exchange_id) + if not ex or not ex.get("enabled"): + raise HTTPException(status_code=404, detail="账户未启用") + url = f"{ex['agent_url'].rstrip('/')}/emergency/close-all" + async with httpx.AsyncClient() as client: + r = await client.post(url, headers=_agent_headers(), timeout=120.0) + try: + body = r.json() + except Exception: + body = {"raw": (r.text or "")[:2000]} + ok = bool(isinstance(body, dict) and body.get("ok")) + out = {"exchange": ex, "status_code": r.status_code, "payload": body, "ok": ok} + if ok and isinstance(body, dict): + closed = body.get("closed") or [] + n = len(closed) if isinstance(closed, list) else 0 + if n > 0: + risk_sync = await _notify_instance_user_close(client, ex, count=n) + if isinstance(risk_sync, dict): + out["risk_sync"] = risk_sync + _schedule_board_refresh() + return out + + +@app.post("/api/close-all") +async def api_close_all(body: CloseAllBody | None = Body(default=None)): + excl = set(body.exclude_ids if body else []) + excl |= env_force_disabled_ids() + targets = [x for x in enabled_exchanges() if str(x["id"]) not in excl] + async with httpx.AsyncClient() as client: + + async def one(ex: dict): + url = f"{ex['agent_url'].rstrip('/')}/emergency/close-all" + try: + r = await client.post(url, headers=_agent_headers(), timeout=120.0) + try: + payload = r.json() + except Exception: + payload = {"raw": (r.text or "")[:2000]} + row = {"id": ex["id"], "name": ex["name"], "status_code": r.status_code, "payload": payload} + if isinstance(payload, dict) and payload.get("ok"): + closed = payload.get("closed") or [] + n = len(closed) if isinstance(closed, list) else 0 + if n > 0: + risk_sync = await _notify_instance_user_close(client, ex, count=n) + if isinstance(risk_sync, dict): + row["risk_sync"] = risk_sync + return row + except Exception as e: + return {"id": ex["id"], "name": ex["name"], "status_code": None, "error": str(e)} + + results = await asyncio.gather(*[one(ex) for ex in targets]) + _schedule_board_refresh() + return {"results": list(results)} + + +def _trade_removed_response(): + """旧版前端或缓存页面仍会请求 /api/trade/*,勿解析表单,直接返回说明。""" + return JSONResponse( + { + "ok": False, + "result": { + "ok": False, + "messages": [ + "中控已移除下单区。请在监控卡片点击「实例」," + "进入对应 crypto_monitor_* 网页添加关键位或下单。" + ], + }, + "deprecated": True, + }, + status_code=410, + ) + + +def _parse_anchor_ms(at: str = "", anchor_ms: str = "") -> int | None: + raw = (anchor_ms or at or "").strip() + if not raw: + return None + return parse_wall_clock_ms(raw) + + +@app.get("/api/archive/meta") +def api_archive_meta(): + init_archive_db() + exchanges = [] + for ex in enabled_exchanges(load_settings()): + exchanges.append( + { + "id": ex.get("id"), + "key": ex.get("key"), + "name": ex.get("name"), + } + ) + return { + "ok": True, + "timeframes": sorted(ARCHIVE_TIMEFRAMES), + "default_timeframe": ARCHIVE_DEFAULT_TIMEFRAME, + "seed_lookback_days": ARCHIVE_SEED_LOOKBACK_DAYS, + "sync_interval_sec": ARCHIVE_SYNC_INTERVAL_SEC, + "visible_bars_default": ARCHIVE_VISIBLE_BARS_DEFAULT, + "exchanges": exchanges, + "last_sync": _last_archive_sync, + } + + +@app.get("/api/archive/list") +def api_archive_list( + exchange_key: str = "", + filter_profit: str = "", + filter_loss: str = "", + filter_sick: str = "", + filter_emotion: str = "", +): + init_archive_db() + rows = list_symbol_rows( + exchange_key=exchange_key, + filter_profit=(filter_profit or "").lower() in ("1", "true", "yes", "on"), + filter_loss=(filter_loss or "").lower() in ("1", "true", "yes", "on"), + filter_sick=(filter_sick or "").lower() in ("1", "true", "yes", "on"), + filter_emotion=(filter_emotion or "").lower() in ("1", "true", "yes", "on"), + ) + return {"ok": True, "rows": rows, "count": len(rows)} + + +@app.get("/api/archive/daily-trades") +def api_archive_daily_trades( + period: str = "", + trading_day: str = "", + date_from: str = "", + date_to: str = "", + exchange_key: str = "", + filter_profit: str = "", + filter_loss: str = "", + filter_sick: str = "", + search: str = "", +): + init_archive_db() + payload = list_daily_trades( + trading_day=trading_day, + period=period or "today", + date_from=date_from, + date_to=date_to, + exchange_key=exchange_key, + filter_profit=(filter_profit or "").lower() in ("1", "true", "yes", "on"), + filter_loss=(filter_loss or "").lower() in ("1", "true", "yes", "on"), + filter_sick=(filter_sick or "").lower() in ("1", "true", "yes", "on"), + search=search, + ) + return {"ok": True, **payload} + + +@app.get("/api/archive/calendar") +def api_archive_calendar( + year: int = 0, + month: int = 0, + exchange_key: str = "", +): + init_archive_db() + if year <= 0 or month <= 0: + td = today_trading_day() + parts = td.split("-") + year = int(parts[0]) + month = int(parts[1]) + try: + payload = list_archive_calendar(year, month, exchange_key=exchange_key) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + return {"ok": True, **payload} + + +@app.get("/api/archive/quotes") +def api_archive_quotes(): + init_archive_db() + rows = list_review_quotes() + return {"ok": True, "quotes": rows, "count": len(rows), "max": ARCHIVE_QUOTES_MAX} + + +class ArchiveQuoteBody(BaseModel): + quote_date: str = "" + content: str = "" + + +@app.post("/api/archive/quotes") +def api_archive_quote_create(body: ArchiveQuoteBody = Body(...)): + init_archive_db() + try: + row = create_review_quote(body.quote_date, body.content) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + return {"ok": True, "quote": row} + + +@app.patch("/api/archive/quotes/{quote_id}") +def api_archive_quote_update(quote_id: int, body: ArchiveQuoteBody = Body(...)): + init_archive_db() + try: + row = update_review_quote( + int(quote_id), + quote_date=body.quote_date or None, + content=body.content if body.content is not None else None, + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + if not row: + raise HTTPException(status_code=404, detail="语录不存在") + return {"ok": True, "quote": row} + + +@app.delete("/api/archive/quotes/{quote_id}") +def api_archive_quote_delete(quote_id: int): + init_archive_db() + if not delete_review_quote(int(quote_id)): + raise HTTPException(status_code=404, detail="语录不存在") + return {"ok": True, "id": int(quote_id)} + + +class MacroEventBody(BaseModel): + event_type: str = "" + event_at: str = "" + note: str = "" + + +@app.get("/api/macro-calendar/meta") +def api_macro_calendar_meta(): + init_macro_calendar_db() + return { + "ok": True, + "event_types": [ + {"id": k, "label": MACRO_EVENT_LABELS[k]} for k in MACRO_EVENT_TYPES + ], + "window_before_minutes": 60, + "window_after_minutes": 60, + "timezone": "Asia/Shanghai", + } + + +@app.get("/api/macro-calendar/events") +def api_macro_calendar_events(): + init_macro_calendar_db() + rows = list_macro_events() + return {"ok": True, "events": rows, "count": len(rows)} + + +@app.get("/api/macro-calendar/active") +def api_macro_calendar_active(): + init_macro_calendar_db() + alerts = list_active_alerts() + return {"ok": True, "alerts": alerts, "count": len(alerts)} + + +@app.post("/api/macro-calendar/events") +def api_macro_calendar_create(body: MacroEventBody = Body(...)): + init_macro_calendar_db() + try: + row = create_macro_event(body.event_type, body.event_at, note=body.note) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + return {"ok": True, "event": row} + + +@app.patch("/api/macro-calendar/events/{event_id}") +def api_macro_calendar_update(event_id: int, body: MacroEventBody = Body(...)): + init_macro_calendar_db() + try: + row = update_macro_event( + int(event_id), + event_type=body.event_type or None, + event_at=body.event_at or None, + note=body.note, + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + if not row: + raise HTTPException(status_code=404, detail="记录不存在") + return {"ok": True, "event": row} + + +@app.delete("/api/macro-calendar/events/{event_id}") +def api_macro_calendar_delete(event_id: int): + init_macro_calendar_db() + if not delete_macro_event(int(event_id)): + raise HTTPException(status_code=404, detail="记录不存在") + return {"ok": True, "id": int(event_id)} + + +@app.get("/api/archive/detail") +def api_archive_detail(exchange_key: str = "", symbol: str = ""): + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + if not ex_k or not sym: + raise HTTPException(status_code=400, detail="缺少 exchange_key 或 symbol") + init_archive_db() + trades = load_symbol_trades(ex_k, sym) + return {"ok": True, "exchange_key": ex_k, "symbol": sym, "trades": trades} + + +@app.get("/api/archive/ohlcv") +def api_archive_ohlcv( + exchange_key: str = "", + symbol: str = "", + timeframe: str = ARCHIVE_DEFAULT_TIMEFRAME, + mode: str = "hold", + anchor_ms: str = "", + opened_ms: str = "", + closed_ms: str = "", + range: str = "", + at: str = "", + bars: str = "", +): + ex_k = (exchange_key or "").strip().lower() + sym = (symbol or "").strip().upper() + if not ex_k or not sym: + raise HTTPException(status_code=400, detail="缺少 exchange_key 或 symbol") + init_archive_db() + anchor = _parse_anchor_ms(at, anchor_ms) + open_ms = _parse_anchor_ms("", opened_ms) + close_ms = _parse_anchor_ms("", closed_ms) + try: + bar_n = int(bars) if (bars or "").strip().isdigit() else ARCHIVE_VISIBLE_BARS_DEFAULT + except ValueError: + bar_n = ARCHIVE_VISIBLE_BARS_DEFAULT + result = resolve_archive_chart( + ex_k, + sym, + timeframe, + anchor_ms=anchor, + opened_ms=open_ms, + closed_ms=close_ms, + mode=mode, + bars=bar_n, + range_mode=(range or "").strip().lower() or "window", + ) + if not result.get("ok"): + raise HTTPException(status_code=404, detail=result.get("msg") or "无 K 线") + return result + + +class ArchiveOverlayBody(BaseModel): + behavior_tag: str = "" + note: str = "" + + +@app.patch("/api/archive/trade/{exchange_key}/{trade_id}") +def api_archive_trade_overlay( + exchange_key: str, + trade_id: int, + body: ArchiveOverlayBody = Body(...), +): + ex_k = (exchange_key or "").strip().lower() + if not ex_k: + raise HTTPException(status_code=400, detail="缺少 exchange_key") + init_archive_db() + out = upsert_trade_overlay( + ex_k, + int(trade_id), + behavior_tag=body.behavior_tag, + note=body.note, + ) + return {"ok": True, "overlay": out} + + +@app.delete("/api/archive/trade/{exchange_key}/{trade_id}") +def api_archive_trade_delete(exchange_key: str, trade_id: int): + from lib.hub.hub_symbol_archive_lib import delete_trade_from_archive + + ex_k = (exchange_key or "").strip().lower() + if not ex_k: + raise HTTPException(status_code=400, detail="缺少 exchange_key") + init_archive_db() + removed = delete_trade_from_archive(ex_k, int(trade_id)) + if not removed: + raise HTTPException(status_code=404, detail="档案中无该笔交易") + return {"ok": True, "exchange_key": ex_k, "trade_id": int(trade_id)} + + +@app.post("/api/archive/sync") +async def api_archive_sync(): + body = await _run_archive_sync_once() + return body + + +@app.get("/api/entry-plans/meta") +def api_entry_plans_meta(): + init_entry_plan_db() + exchanges = [] + for ex in enabled_exchanges(load_settings()): + exchanges.append( + { + "id": ex.get("id"), + "key": ex.get("key"), + "name": ex.get("name"), + } + ) + return {"ok": True, **entry_plan_meta_payload(exchanges)} + + +@app.get("/api/entry-plans") +def api_entry_plans_list(status: str = "active"): + init_entry_plan_db() + try: + rows = list_entry_plans(status=status) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + return {"ok": True, "plans": rows, "count": len(rows), "status": status.strip().lower()} + + +@app.get("/api/entry-plans/stats") +def api_entry_plan_stats( + dimension: str = "symbol", + period: str = "all", + date_from: str = "", + date_to: str = "", +): + init_entry_plan_db() + try: + stats = compute_entry_plan_stats( + dimension=dimension, + period=period, + date_from=date_from, + date_to=date_to, + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + return {"ok": True, "stats": stats} + + +@app.get("/api/entry-plans/{plan_id}") +def api_entry_plan_detail(plan_id: int): + init_entry_plan_db() + row = get_entry_plan(int(plan_id)) + if not row: + raise HTTPException(status_code=404, detail="计划不存在") + return {"ok": True, "plan": row} + + +class EntryPlanBody(BaseModel): + plan_date: str = "" + exchange_key: str = "" + symbol: str = "" + plan_type: str = "" + trend_timeframe: str = "" + entry_timeframe: str = "" + direction: str = "" + target_level: str = "" + current_range: str = "" + entry_scheme: str = "" + result: str | None = None + pnl_amount: float | None = None + note: str = "" + + +@app.post("/api/entry-plans") +def api_entry_plan_create(body: EntryPlanBody = Body(...)): + init_entry_plan_db() + try: + row = create_entry_plan(body.model_dump(exclude_unset=True)) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + return {"ok": True, "plan": row} + + +@app.patch("/api/entry-plans/{plan_id}") +def api_entry_plan_update(plan_id: int, body: EntryPlanBody = Body(...)): + init_entry_plan_db() + payload = body.model_dump(exclude_unset=True) + if not payload: + raise HTTPException(status_code=400, detail="无更新字段") + try: + row = update_entry_plan(int(plan_id), payload) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + if not row: + raise HTTPException(status_code=404, detail="计划不存在") + return {"ok": True, "plan": row} + + +@app.delete("/api/entry-plans/{plan_id}") +def api_entry_plan_delete(plan_id: int): + init_entry_plan_db() + try: + ok = delete_entry_plan(int(plan_id)) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + if not ok: + raise HTTPException(status_code=404, detail="计划不存在或已归档") + return {"ok": True, "id": int(plan_id)} + + +@app.get("/api/hub/fund-overview") +def api_hub_fund_overview(): + from lib.hub.hub_fund_history_lib import build_fund_overview + from hub_ai.config import trading_day_reset_hour + + settings = load_settings() + snap = board_store.snapshot_dict() + payload = build_fund_overview( + enabled_exchanges(settings), + board_rows=snap.get("rows") or [], + reset_hour=trading_day_reset_hour(), + updated_at=snap.get("updated_at"), + ) + return payload + + +@app.get("/api/ping") +def api_ping(): + return { + "ok": True, + "service": "manual-trading-hub", + "build": HUB_BUILD, + "trade_ui": False, + "features": ["monitor", "settings", "auth", "board_sse", "dashboard_sse", "archive", "dashboard", "funds", "macro_calendar"], + "board_poll_interval_sec": HUB_BOARD_POLL_INTERVAL, + "board_version": board_store.version, + "board_aggregating": board_store.aggregating, + "board_updated_at": (board_store.payload or {}).get("updated_at") + if isinstance(board_store.payload, dict) + else None, + "board_error": board_store.last_error, + "dashboard_poll_interval_sec": DASHBOARD_POLL_INTERVAL_SEC, + "dashboard_version": dashboard_store.version, + "dashboard_aggregating": dashboard_store.aggregating, + "dashboard_updated_at": (dashboard_store.payload or {}).get("updated_at") + if isinstance(dashboard_store.payload, dict) + else None, + "dashboard_error": dashboard_store.last_error, + "password_required": password_required(), + "env_disabled_ids": sorted(env_force_disabled_ids()), + "hub_disabled_ids_raw": (os.getenv("HUB_DISABLED_IDS") or ""), + } + + +@app.post("/api/trade/order/{exchange_id}") +@app.post("/api/trade/key/{exchange_id}") +@app.post("/api/trade/trend/preview/{exchange_id}") +@app.post("/api/trade/trend/execute/{exchange_id}") +async def api_trade_removed(exchange_id: str): + return _trade_removed_response() + + +@app.get("/api/trade/meta/{exchange_id}") +@app.get("/api/trade/trend/preview/{exchange_id}/{preview_id}") +async def api_trade_removed_get(exchange_id: str, preview_id: str = ""): + return _trade_removed_response() + + +def main(): + import uvicorn + + print( + f"manual-trading-hub start build={HUB_BUILD} listen={HUB_HOST}:{HUB_PORT}", + flush=True, + ) + uvicorn.run(app, host=HUB_HOST, port=HUB_PORT, log_level="info", access_log=False) + + +if __name__ == "__main__": + main() diff --git a/manual_trading_hub/hub_ai/archive_quote.py b/manual_trading_hub/hub_ai/archive_quote.py index 6dac822..95af2ca 100644 --- a/manual_trading_hub/hub_ai/archive_quote.py +++ b/manual_trading_hub/hub_ai/archive_quote.py @@ -1,161 +1,161 @@ -"""内照明心复盘语录 → 交易教练点评。""" -from __future__ import annotations - -from typing import Any - -from hub_ai.client import generate_text, model_label -from hub_ai.rolling_summary import refresh_session_rolling_summary -from hub_ai.text_util import clip_text, is_ai_error_reply -from hub_ai.config import ( - CHAT_MAX_CONTINUATIONS, - CHAT_MAX_OUTPUT_TOKENS, - CHAT_TEMPERATURE, - CHAT_USER_MESSAGE_MAX_CHARS, -) -from hub_ai.prompts import CHAT_SYSTEM, build_archive_quote_review_prompt -from hub_ai.store import ( - CHAT_BOT_TRADING, - append_chat_message, - create_new_session, - delete_chat_session, - get_active_session, - list_chat_sessions, -) -from hub_symbol_archive_lib import list_daily_trades - - -def _tag_label(tag: str) -> str: - t = (tag or "").strip().lower() - if t == "sick": - return "犯病" - if t == "emotion": - return "情绪化" - return t or "—" - - -def _fmt_pnl(v: Any) -> str: - try: - n = float(v or 0) - except (TypeError, ValueError): - return "—" - sign = "+" if n > 0 else "" - return f"{sign}{n:.2f}U" - - -def _fmt_pct(v: Any) -> str: - try: - n = float(v) - except (TypeError, ValueError): - return "—" - return f"{n:.1f}%" - - -def _fmt_rr(v: Any) -> str: - try: - n = float(v) - except (TypeError, ValueError): - return "—" - return f"{n:.2f}:1" - - -def format_archive_trades_for_ai(payload: dict[str, Any]) -> str: - trades = payload.get("trades") or [] - stats = payload.get("stats") or {} - lines = [ - ( - f"统计:开仓 {int(stats.get('open_count') or 0)} 笔," - f"盈利 {int(stats.get('win_count') or 0)} / 亏损 {int(stats.get('loss_count') or 0)}," - f"平均盈利 {_fmt_pnl(stats.get('avg_win'))},平均亏损 {_fmt_pnl(stats.get('avg_loss'))}," - f"胜率 {_fmt_pct(stats.get('win_rate'))},盈亏比 {_fmt_rr(stats.get('profit_loss_ratio'))}," - f"最大盈利 {_fmt_pnl(stats.get('max_win'))},最大亏损 {_fmt_pnl(stats.get('max_loss'))}," - f"犯病 {int(stats.get('sick_count') or 0)} 笔," - f"盈亏合计 {_fmt_pnl(stats.get('pnl_total'))}," - f"剔除犯病盈亏 {_fmt_pnl(stats.get('pnl_ex_sick'))}" - ) - ] - if not trades: - lines.append("(该日无交易记录)") - return "\n".join(lines) - max_rows = 50 - if len(trades) > max_rows: - lines.append(f"(共 {len(trades)} 笔,以下展示最近 {max_rows} 笔)") - for i, t in enumerate(trades[:max_rows], 1): - ex = str(t.get("exchange_key") or t.get("account_exchange_key") or "—") - sym = str(t.get("symbol") or "—") - direction = str(t.get("direction") or "—") - opened = str(t.get("opened_at") or "—") - closed = str(t.get("closed_at") or "—") - hold = str(t.get("hold_minutes_text") or t.get("hold_minutes") or "—") - result = str(t.get("result") or "—") - pnl = _fmt_pnl(t.get("pnl_amount")) - entry = str(t.get("entry_type") or t.get("entry_reason") or t.get("monitor_type") or "—") - tag = _tag_label(str(t.get("behavior_tag") or "")) - note = clip_text(str(t.get("note") or "").strip(), 80) - line = ( - f"{i}. {ex} | {sym} | {direction} | 开仓类型 {entry} | " - f"开 {opened} | 平 {closed} | 持仓 {hold} | 结果 {result} | " - f"盈亏 {pnl} | 标签 {tag}" - ) - if note: - line += f" | 备注 {note}" - lines.append(line) - return "\n".join(lines) - - -def send_archive_quote_review( - *, - quote_date: str, - content: str, -) -> dict[str, Any]: - text = (content or "").strip() - if not text: - return {"ok": False, "msg": "语录内容不能为空"} - day = (quote_date or "").strip()[:10] - if not day: - return {"ok": False, "msg": "语录日期无效"} - - session = create_new_session( - trading_day=day, - title=f"复盘 {day}", - bot_mode=CHAT_BOT_TRADING, - ) - sid = session["id"] - - archive_payload = list_daily_trades(trading_day=day, period="today") - archive_trades_text = format_archive_trades_for_ai(archive_payload) - user_for_prompt = clip_text(text, CHAT_USER_MESSAGE_MAX_CHARS) - - user_prompt = build_archive_quote_review_prompt( - quote_date=day, - archive_trades_text=archive_trades_text, - user_message=user_for_prompt, - ) - reply = generate_text( - system=CHAT_SYSTEM, - user=user_prompt, - temperature=CHAT_TEMPERATURE, - max_tokens=CHAT_MAX_OUTPUT_TOKENS, - max_continuations=CHAT_MAX_CONTINUATIONS, - ) - if is_ai_error_reply(reply): - delete_chat_session(sid) - return {"ok": False, "msg": reply} - - append_chat_message(sid, "user", text) - session = append_chat_message(sid, "assistant", reply) - refresh_session_rolling_summary( - sid, - prior_summary="", - user_text=text, - assistant_text=reply, - bot_mode=CHAT_BOT_TRADING, - ) - session = get_active_session() or session - return { - "ok": True, - "trading_day": day, - "session": session, - "sessions": list_chat_sessions(), - "reply": reply, - "model": model_label(), - } +"""内照明心复盘语录 → 交易教练点评。""" +from __future__ import annotations + +from typing import Any + +from hub_ai.client import generate_text, model_label +from hub_ai.rolling_summary import refresh_session_rolling_summary +from hub_ai.text_util import clip_text, is_ai_error_reply +from hub_ai.config import ( + CHAT_MAX_CONTINUATIONS, + CHAT_MAX_OUTPUT_TOKENS, + CHAT_TEMPERATURE, + CHAT_USER_MESSAGE_MAX_CHARS, +) +from hub_ai.prompts import CHAT_SYSTEM, build_archive_quote_review_prompt +from hub_ai.store import ( + CHAT_BOT_TRADING, + append_chat_message, + create_new_session, + delete_chat_session, + get_active_session, + list_chat_sessions, +) +from lib.hub.hub_symbol_archive_lib import list_daily_trades + + +def _tag_label(tag: str) -> str: + t = (tag or "").strip().lower() + if t == "sick": + return "犯病" + if t == "emotion": + return "情绪化" + return t or "—" + + +def _fmt_pnl(v: Any) -> str: + try: + n = float(v or 0) + except (TypeError, ValueError): + return "—" + sign = "+" if n > 0 else "" + return f"{sign}{n:.2f}U" + + +def _fmt_pct(v: Any) -> str: + try: + n = float(v) + except (TypeError, ValueError): + return "—" + return f"{n:.1f}%" + + +def _fmt_rr(v: Any) -> str: + try: + n = float(v) + except (TypeError, ValueError): + return "—" + return f"{n:.2f}:1" + + +def format_archive_trades_for_ai(payload: dict[str, Any]) -> str: + trades = payload.get("trades") or [] + stats = payload.get("stats") or {} + lines = [ + ( + f"统计:开仓 {int(stats.get('open_count') or 0)} 笔," + f"盈利 {int(stats.get('win_count') or 0)} / 亏损 {int(stats.get('loss_count') or 0)}," + f"平均盈利 {_fmt_pnl(stats.get('avg_win'))},平均亏损 {_fmt_pnl(stats.get('avg_loss'))}," + f"胜率 {_fmt_pct(stats.get('win_rate'))},盈亏比 {_fmt_rr(stats.get('profit_loss_ratio'))}," + f"最大盈利 {_fmt_pnl(stats.get('max_win'))},最大亏损 {_fmt_pnl(stats.get('max_loss'))}," + f"犯病 {int(stats.get('sick_count') or 0)} 笔," + f"盈亏合计 {_fmt_pnl(stats.get('pnl_total'))}," + f"剔除犯病盈亏 {_fmt_pnl(stats.get('pnl_ex_sick'))}" + ) + ] + if not trades: + lines.append("(该日无交易记录)") + return "\n".join(lines) + max_rows = 50 + if len(trades) > max_rows: + lines.append(f"(共 {len(trades)} 笔,以下展示最近 {max_rows} 笔)") + for i, t in enumerate(trades[:max_rows], 1): + ex = str(t.get("exchange_key") or t.get("account_exchange_key") or "—") + sym = str(t.get("symbol") or "—") + direction = str(t.get("direction") or "—") + opened = str(t.get("opened_at") or "—") + closed = str(t.get("closed_at") or "—") + hold = str(t.get("hold_minutes_text") or t.get("hold_minutes") or "—") + result = str(t.get("result") or "—") + pnl = _fmt_pnl(t.get("pnl_amount")) + entry = str(t.get("entry_type") or t.get("entry_reason") or t.get("monitor_type") or "—") + tag = _tag_label(str(t.get("behavior_tag") or "")) + note = clip_text(str(t.get("note") or "").strip(), 80) + line = ( + f"{i}. {ex} | {sym} | {direction} | 开仓类型 {entry} | " + f"开 {opened} | 平 {closed} | 持仓 {hold} | 结果 {result} | " + f"盈亏 {pnl} | 标签 {tag}" + ) + if note: + line += f" | 备注 {note}" + lines.append(line) + return "\n".join(lines) + + +def send_archive_quote_review( + *, + quote_date: str, + content: str, +) -> dict[str, Any]: + text = (content or "").strip() + if not text: + return {"ok": False, "msg": "语录内容不能为空"} + day = (quote_date or "").strip()[:10] + if not day: + return {"ok": False, "msg": "语录日期无效"} + + session = create_new_session( + trading_day=day, + title=f"复盘 {day}", + bot_mode=CHAT_BOT_TRADING, + ) + sid = session["id"] + + archive_payload = list_daily_trades(trading_day=day, period="today") + archive_trades_text = format_archive_trades_for_ai(archive_payload) + user_for_prompt = clip_text(text, CHAT_USER_MESSAGE_MAX_CHARS) + + user_prompt = build_archive_quote_review_prompt( + quote_date=day, + archive_trades_text=archive_trades_text, + user_message=user_for_prompt, + ) + reply = generate_text( + system=CHAT_SYSTEM, + user=user_prompt, + temperature=CHAT_TEMPERATURE, + max_tokens=CHAT_MAX_OUTPUT_TOKENS, + max_continuations=CHAT_MAX_CONTINUATIONS, + ) + if is_ai_error_reply(reply): + delete_chat_session(sid) + return {"ok": False, "msg": reply} + + append_chat_message(sid, "user", text) + session = append_chat_message(sid, "assistant", reply) + refresh_session_rolling_summary( + sid, + prior_summary="", + user_text=text, + assistant_text=reply, + bot_mode=CHAT_BOT_TRADING, + ) + session = get_active_session() or session + return { + "ok": True, + "trading_day": day, + "session": session, + "sessions": list_chat_sessions(), + "reply": reply, + "model": model_label(), + } diff --git a/manual_trading_hub/hub_ai/chat.py b/manual_trading_hub/hub_ai/chat.py index e920633..60c8a1d 100644 --- a/manual_trading_hub/hub_ai/chat.py +++ b/manual_trading_hub/hub_ai/chat.py @@ -1,275 +1,275 @@ -"""中控 AI:单会话聊天(直到用户点击新开)。""" -from __future__ import annotations - -import threading -from typing import Any, Optional - -from hub_ai.attachments import parse_chat_attachments -from hub_ai.client import generate_text, model_label -from hub_ai.config import ( - CHAT_CONTEXT_MAX_CHARS, - CHAT_FOLLOWUP_CONTEXT_MAX_CHARS, - CHAT_HISTORY_MAX_CHARS_PER_MSG, - CHAT_MAX_CONTINUATIONS, - CHAT_MAX_HISTORY_TURNS, - CHAT_MAX_OUTPUT_TOKENS, - CHAT_PROMPT_MAX_CHARS, - CHAT_SUMMARY_EXCERPT_MAX_CHARS, - CHAT_TEMPERATURE, - CHAT_USER_MESSAGE_MAX_CHARS, - trading_day_reset_hour, -) -from hub_trades_lib import current_trading_day -from hub_ai.context import ( - build_chat_context, - format_chat_context_for_chat, - format_chat_position_overview, -) -from hub_ai.prompts import ( - CHAT_GENERAL_SYSTEM, - CHAT_SYSTEM, - build_chat_user_prompt, - build_general_chat_user_prompt, -) -from hub_ai.rolling_summary import refresh_session_rolling_summary -from hub_ai.store import ( - CHAT_BOT_GENERAL, - CHAT_BOT_TRADING, - append_chat_message, - create_new_session, - delete_chat_session, - ensure_active_session, - get_active_session, - list_chat_sessions, - load_chat_store, - set_active_session, - summary_excerpt_for_chat, -) -from hub_ai.text_util import clip_text, is_ai_error_reply - - -def _is_ai_error_reply(text: str) -> bool: - return is_ai_error_reply(text) - - -def _clip_text(text: str, max_chars: int) -> str: - return clip_text(text, max_chars) - - -def _history_lines( - messages: list[dict], - max_turns: int = CHAT_MAX_HISTORY_TURNS, - *, - max_chars_per_msg: int = CHAT_HISTORY_MAX_CHARS_PER_MSG, - total_max_chars: int | None = None, -) -> str: - rows = [m for m in (messages or []) if m.get("role") in ("user", "assistant")] - rows = rows[-max_turns * 2 :] - lines = [] - for m in rows: - role = "用户" if m.get("role") == "user" else "搭档" - content = str(m.get("content") or "").strip() - if m.get("role") == "assistant" and _is_ai_error_reply(content): - continue - att = m.get("attachments") or [] - if att: - names = "、".join(str(a.get("name") or "附件") for a in att[:3]) - content = f"{content} [附件: {names}]".strip() - content = _clip_text(content, max_chars_per_msg) - if content: - lines.append(f"{role}:{content}") - if total_max_chars and total_max_chars > 0: - while lines and len("\n".join(lines)) > total_max_chars: - lines.pop(0) - return "\n".join(lines) - - -def _trading_context_bundle(ctx: dict[str, Any], *, prior_count: int) -> tuple[str, str]: - day = str(ctx.get("trading_day") or (ctx.get("totals") or {}).get("trading_day") or "") - if prior_count <= 0: - brief = format_chat_context_for_chat(ctx, max_chars=CHAT_CONTEXT_MAX_CHARS) - excerpt = summary_excerpt_for_chat(day, max_chars=CHAT_SUMMARY_EXCERPT_MAX_CHARS) - return brief, excerpt - totals = ctx.get("totals") or {} - overview = format_chat_position_overview(ctx) - slim = ( - f"【续聊快照 {day}】平仓盈亏 {totals.get('total_pnl_u')}U | " - f"笔数 {totals.get('closed_count')} | " - f"持仓 {totals.get('open_position_count', 0)} 仓 | " - f"浮盈亏 {totals.get('float_pnl_u')}U" - ) - brief = _clip_text(overview + "\n" + slim, CHAT_FOLLOWUP_CONTEXT_MAX_CHARS) - return brief, "" - - -def _history_budget(*sizes: int) -> int: - used = sum(int(s or 0) for s in sizes) + 2200 - return max(1200, CHAT_PROMPT_MAX_CHARS - used) - - -def _prompt_memory(session: dict, prior_msgs: list[dict]) -> tuple[str, str]: - """续聊优先用滚动摘要;旧会话无摘要时仅带最近 1 轮兜底。""" - rolling = str(session.get("rolling_summary") or "").strip() - if rolling: - return rolling, "" - prior_count = len([m for m in prior_msgs if m.get("role") in ("user", "assistant")]) - if prior_count <= 0: - return "", "" - tail = _history_lines( - prior_msgs, - max_turns=1, - max_chars_per_msg=CHAT_HISTORY_MAX_CHARS_PER_MSG, - ) - return "", tail - - -def get_chat_state() -> dict[str, Any]: - store = load_chat_store() - session = get_active_session() - if session: - session.setdefault("bot_mode", CHAT_BOT_TRADING) - session.setdefault("rolling_summary", "") - return { - "active_session_id": store.get("active_session_id"), - "session": session, - "sessions": list_chat_sessions(), - "model": model_label(), - } - - -def start_new_chat(*, trading_day: str, bot_mode: str = CHAT_BOT_TRADING) -> dict: - session = create_new_session(trading_day=trading_day, bot_mode=bot_mode) - return { - "ok": True, - "session": session, - "sessions": list_chat_sessions(), - "model": model_label(), - } - - -def switch_chat_session(session_id: str) -> dict[str, Any]: - session = set_active_session(session_id) - return { - "ok": True, - "session": session, - "sessions": list_chat_sessions(), - "model": model_label(), - } - - -def remove_chat_session(session_id: str) -> dict[str, Any]: - deleted, new_active = delete_chat_session(session_id) - if not deleted: - return {"ok": False, "msg": "session_not_found"} - session = get_active_session() - return { - "ok": True, - "active_session_id": new_active, - "session": session, - "sessions": list_chat_sessions(), - "model": model_label(), - } - - -def send_chat_message( - exchanges: list[dict], - message: str, - *, - trading_day: str | None = None, - raw_attachments: Optional[list[dict]] = None, -) -> dict[str, Any]: - text = (message or "").strip() - parsed = parse_chat_attachments(raw_attachments or []) - if parsed.get("errors") and not text and not parsed.get("images_b64"): - return {"ok": False, "msg": ";".join(parsed["errors"])} - if not text and not parsed.get("images_b64") and not parsed.get("text_append"): - return {"ok": False, "msg": "消息不能为空"} - - user_visible = text - if parsed.get("text_append"): - user_visible = (user_visible + "\n\n" + parsed["text_append"]).strip() - if not user_visible and parsed.get("attachment_note"): - user_visible = f"(上传了 {parsed['attachment_note']})" - - day = (trading_day or "").strip()[:10] or current_trading_day( - reset_hour=trading_day_reset_hour() - ) - session = ensure_active_session(trading_day=day) - sid = session["id"] - prior_rolling = str(session.get("rolling_summary") or "") - prior_msgs = session.get("messages") or [] - prior_count = len([m for m in prior_msgs if m.get("role") in ("user", "assistant")]) - user_for_prompt = _clip_text(text or user_visible, CHAT_USER_MESSAGE_MAX_CHARS) - rolling_summary, history_tail = _prompt_memory(session, prior_msgs) - - bot_mode = (session.get("bot_mode") or CHAT_BOT_TRADING).strip().lower() - if bot_mode == CHAT_BOT_GENERAL: - user_prompt = build_general_chat_user_prompt( - rolling_summary=rolling_summary, - history_lines=history_tail, - user_message=user_for_prompt, - attachment_note=str(parsed.get("attachment_note") or ""), - ) - if parsed.get("text_append"): - user_prompt += "\n\n【附件正文】\n" + _clip_text(parsed["text_append"], 3000) - system_prompt = CHAT_GENERAL_SYSTEM - else: - ctx = build_chat_context(exchanges, trading_day=day) - day = ctx["trading_day"] - brief_ctx, excerpt = _trading_context_bundle(ctx, prior_count=prior_count) - user_prompt = build_chat_user_prompt( - context_text=brief_ctx, - trading_day=day, - summary_excerpt=excerpt, - rolling_summary=rolling_summary, - history_lines=history_tail, - user_message=user_for_prompt, - attachment_note=str(parsed.get("attachment_note") or ""), - ) - if parsed.get("text_append"): - user_prompt += "\n\n【附件正文】\n" + _clip_text(parsed["text_append"], 3000) - system_prompt = CHAT_SYSTEM - - reply = generate_text( - system=system_prompt, - user=user_prompt, - temperature=CHAT_TEMPERATURE, - images_b64=parsed.get("images_b64") or None, - max_tokens=CHAT_MAX_OUTPUT_TOKENS, - max_continuations=CHAT_MAX_CONTINUATIONS, - ) - if _is_ai_error_reply(reply): - return {"ok": False, "msg": reply, "session_id": sid} - - append_chat_message( - sid, - "user", - user_visible, - attachments=parsed.get("attachment_meta") or [], - ) - session = append_chat_message(sid, "assistant", reply) - summary_kwargs = { - "session_id": sid, - "prior_summary": prior_rolling, - "user_text": user_visible, - "assistant_text": reply, - "bot_mode": bot_mode, - } - - def _refresh_summary_bg() -> None: - try: - refresh_session_rolling_summary(**summary_kwargs) - except Exception: - pass - - threading.Thread(target=_refresh_summary_bg, daemon=True).start() - session = get_active_session() or session - return { - "ok": True, - "trading_day": day, - "session": session, - "sessions": list_chat_sessions(), - "reply": reply, - "model": model_label(), - "attachment_warnings": parsed.get("errors") or [], - } +"""中控 AI:单会话聊天(直到用户点击新开)。""" +from __future__ import annotations + +import threading +from typing import Any, Optional + +from hub_ai.attachments import parse_chat_attachments +from hub_ai.client import generate_text, model_label +from hub_ai.config import ( + CHAT_CONTEXT_MAX_CHARS, + CHAT_FOLLOWUP_CONTEXT_MAX_CHARS, + CHAT_HISTORY_MAX_CHARS_PER_MSG, + CHAT_MAX_CONTINUATIONS, + CHAT_MAX_HISTORY_TURNS, + CHAT_MAX_OUTPUT_TOKENS, + CHAT_PROMPT_MAX_CHARS, + CHAT_SUMMARY_EXCERPT_MAX_CHARS, + CHAT_TEMPERATURE, + CHAT_USER_MESSAGE_MAX_CHARS, + trading_day_reset_hour, +) +from lib.hub.hub_trades_lib import current_trading_day +from hub_ai.context import ( + build_chat_context, + format_chat_context_for_chat, + format_chat_position_overview, +) +from hub_ai.prompts import ( + CHAT_GENERAL_SYSTEM, + CHAT_SYSTEM, + build_chat_user_prompt, + build_general_chat_user_prompt, +) +from hub_ai.rolling_summary import refresh_session_rolling_summary +from hub_ai.store import ( + CHAT_BOT_GENERAL, + CHAT_BOT_TRADING, + append_chat_message, + create_new_session, + delete_chat_session, + ensure_active_session, + get_active_session, + list_chat_sessions, + load_chat_store, + set_active_session, + summary_excerpt_for_chat, +) +from hub_ai.text_util import clip_text, is_ai_error_reply + + +def _is_ai_error_reply(text: str) -> bool: + return is_ai_error_reply(text) + + +def _clip_text(text: str, max_chars: int) -> str: + return clip_text(text, max_chars) + + +def _history_lines( + messages: list[dict], + max_turns: int = CHAT_MAX_HISTORY_TURNS, + *, + max_chars_per_msg: int = CHAT_HISTORY_MAX_CHARS_PER_MSG, + total_max_chars: int | None = None, +) -> str: + rows = [m for m in (messages or []) if m.get("role") in ("user", "assistant")] + rows = rows[-max_turns * 2 :] + lines = [] + for m in rows: + role = "用户" if m.get("role") == "user" else "搭档" + content = str(m.get("content") or "").strip() + if m.get("role") == "assistant" and _is_ai_error_reply(content): + continue + att = m.get("attachments") or [] + if att: + names = "、".join(str(a.get("name") or "附件") for a in att[:3]) + content = f"{content} [附件: {names}]".strip() + content = _clip_text(content, max_chars_per_msg) + if content: + lines.append(f"{role}:{content}") + if total_max_chars and total_max_chars > 0: + while lines and len("\n".join(lines)) > total_max_chars: + lines.pop(0) + return "\n".join(lines) + + +def _trading_context_bundle(ctx: dict[str, Any], *, prior_count: int) -> tuple[str, str]: + day = str(ctx.get("trading_day") or (ctx.get("totals") or {}).get("trading_day") or "") + if prior_count <= 0: + brief = format_chat_context_for_chat(ctx, max_chars=CHAT_CONTEXT_MAX_CHARS) + excerpt = summary_excerpt_for_chat(day, max_chars=CHAT_SUMMARY_EXCERPT_MAX_CHARS) + return brief, excerpt + totals = ctx.get("totals") or {} + overview = format_chat_position_overview(ctx) + slim = ( + f"【续聊快照 {day}】平仓盈亏 {totals.get('total_pnl_u')}U | " + f"笔数 {totals.get('closed_count')} | " + f"持仓 {totals.get('open_position_count', 0)} 仓 | " + f"浮盈亏 {totals.get('float_pnl_u')}U" + ) + brief = _clip_text(overview + "\n" + slim, CHAT_FOLLOWUP_CONTEXT_MAX_CHARS) + return brief, "" + + +def _history_budget(*sizes: int) -> int: + used = sum(int(s or 0) for s in sizes) + 2200 + return max(1200, CHAT_PROMPT_MAX_CHARS - used) + + +def _prompt_memory(session: dict, prior_msgs: list[dict]) -> tuple[str, str]: + """续聊优先用滚动摘要;旧会话无摘要时仅带最近 1 轮兜底。""" + rolling = str(session.get("rolling_summary") or "").strip() + if rolling: + return rolling, "" + prior_count = len([m for m in prior_msgs if m.get("role") in ("user", "assistant")]) + if prior_count <= 0: + return "", "" + tail = _history_lines( + prior_msgs, + max_turns=1, + max_chars_per_msg=CHAT_HISTORY_MAX_CHARS_PER_MSG, + ) + return "", tail + + +def get_chat_state() -> dict[str, Any]: + store = load_chat_store() + session = get_active_session() + if session: + session.setdefault("bot_mode", CHAT_BOT_TRADING) + session.setdefault("rolling_summary", "") + return { + "active_session_id": store.get("active_session_id"), + "session": session, + "sessions": list_chat_sessions(), + "model": model_label(), + } + + +def start_new_chat(*, trading_day: str, bot_mode: str = CHAT_BOT_TRADING) -> dict: + session = create_new_session(trading_day=trading_day, bot_mode=bot_mode) + return { + "ok": True, + "session": session, + "sessions": list_chat_sessions(), + "model": model_label(), + } + + +def switch_chat_session(session_id: str) -> dict[str, Any]: + session = set_active_session(session_id) + return { + "ok": True, + "session": session, + "sessions": list_chat_sessions(), + "model": model_label(), + } + + +def remove_chat_session(session_id: str) -> dict[str, Any]: + deleted, new_active = delete_chat_session(session_id) + if not deleted: + return {"ok": False, "msg": "session_not_found"} + session = get_active_session() + return { + "ok": True, + "active_session_id": new_active, + "session": session, + "sessions": list_chat_sessions(), + "model": model_label(), + } + + +def send_chat_message( + exchanges: list[dict], + message: str, + *, + trading_day: str | None = None, + raw_attachments: Optional[list[dict]] = None, +) -> dict[str, Any]: + text = (message or "").strip() + parsed = parse_chat_attachments(raw_attachments or []) + if parsed.get("errors") and not text and not parsed.get("images_b64"): + return {"ok": False, "msg": ";".join(parsed["errors"])} + if not text and not parsed.get("images_b64") and not parsed.get("text_append"): + return {"ok": False, "msg": "消息不能为空"} + + user_visible = text + if parsed.get("text_append"): + user_visible = (user_visible + "\n\n" + parsed["text_append"]).strip() + if not user_visible and parsed.get("attachment_note"): + user_visible = f"(上传了 {parsed['attachment_note']})" + + day = (trading_day or "").strip()[:10] or current_trading_day( + reset_hour=trading_day_reset_hour() + ) + session = ensure_active_session(trading_day=day) + sid = session["id"] + prior_rolling = str(session.get("rolling_summary") or "") + prior_msgs = session.get("messages") or [] + prior_count = len([m for m in prior_msgs if m.get("role") in ("user", "assistant")]) + user_for_prompt = _clip_text(text or user_visible, CHAT_USER_MESSAGE_MAX_CHARS) + rolling_summary, history_tail = _prompt_memory(session, prior_msgs) + + bot_mode = (session.get("bot_mode") or CHAT_BOT_TRADING).strip().lower() + if bot_mode == CHAT_BOT_GENERAL: + user_prompt = build_general_chat_user_prompt( + rolling_summary=rolling_summary, + history_lines=history_tail, + user_message=user_for_prompt, + attachment_note=str(parsed.get("attachment_note") or ""), + ) + if parsed.get("text_append"): + user_prompt += "\n\n【附件正文】\n" + _clip_text(parsed["text_append"], 3000) + system_prompt = CHAT_GENERAL_SYSTEM + else: + ctx = build_chat_context(exchanges, trading_day=day) + day = ctx["trading_day"] + brief_ctx, excerpt = _trading_context_bundle(ctx, prior_count=prior_count) + user_prompt = build_chat_user_prompt( + context_text=brief_ctx, + trading_day=day, + summary_excerpt=excerpt, + rolling_summary=rolling_summary, + history_lines=history_tail, + user_message=user_for_prompt, + attachment_note=str(parsed.get("attachment_note") or ""), + ) + if parsed.get("text_append"): + user_prompt += "\n\n【附件正文】\n" + _clip_text(parsed["text_append"], 3000) + system_prompt = CHAT_SYSTEM + + reply = generate_text( + system=system_prompt, + user=user_prompt, + temperature=CHAT_TEMPERATURE, + images_b64=parsed.get("images_b64") or None, + max_tokens=CHAT_MAX_OUTPUT_TOKENS, + max_continuations=CHAT_MAX_CONTINUATIONS, + ) + if _is_ai_error_reply(reply): + return {"ok": False, "msg": reply, "session_id": sid} + + append_chat_message( + sid, + "user", + user_visible, + attachments=parsed.get("attachment_meta") or [], + ) + session = append_chat_message(sid, "assistant", reply) + summary_kwargs = { + "session_id": sid, + "prior_summary": prior_rolling, + "user_text": user_visible, + "assistant_text": reply, + "bot_mode": bot_mode, + } + + def _refresh_summary_bg() -> None: + try: + refresh_session_rolling_summary(**summary_kwargs) + except Exception: + pass + + threading.Thread(target=_refresh_summary_bg, daemon=True).start() + session = get_active_session() or session + return { + "ok": True, + "trading_day": day, + "session": session, + "sessions": list_chat_sessions(), + "reply": reply, + "model": model_label(), + "attachment_warnings": parsed.get("errors") or [], + } diff --git a/manual_trading_hub/hub_ai/client.py b/manual_trading_hub/hub_ai/client.py index 3a485e2..2902bdd 100644 --- a/manual_trading_hub/hub_ai/client.py +++ b/manual_trading_hub/hub_ai/client.py @@ -1,42 +1,42 @@ -"""中控 AI 模型调用(共用 ai_client 配置,逻辑独立)。""" -from __future__ import annotations - -import sys -from pathlib import Path -from typing import Optional, Sequence - -_REPO_ROOT = Path(__file__).resolve().parents[2] -if str(_REPO_ROOT) not in sys.path: - sys.path.insert(0, str(_REPO_ROOT)) - -from ai_client import ai_generate, ai_generate_chat, ai_provider_label # noqa: E402 - - -def model_label() -> str: - return ai_provider_label() - - -def generate_text( - *, - system: str, - user: str, - temperature: float, - images_b64: Optional[Sequence[str]] = None, - max_tokens: int | None = None, - max_continuations: int = 3, -) -> str: - if max_tokens is not None and max_tokens > 0: - return ai_generate_chat( - system=system, - user=user, - temperature=temperature, - images_b64=images_b64, - max_tokens=int(max_tokens), - max_continuations=max_continuations, - ) - prompt = f"{system.strip()}\n\n---\n\n{user.strip()}" - return ai_generate( - prompt, - temperature=temperature, - images_b64=images_b64, - ) +"""中控 AI 模型调用(共用 ai_client 配置,逻辑独立)。""" +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Optional, Sequence + +_REPO_ROOT = Path(__file__).resolve().parents[2] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from lib.ai.ai_client import ai_generate, ai_generate_chat, ai_provider_label # noqa: E402 + + +def model_label() -> str: + return ai_provider_label() + + +def generate_text( + *, + system: str, + user: str, + temperature: float, + images_b64: Optional[Sequence[str]] = None, + max_tokens: int | None = None, + max_continuations: int = 3, +) -> str: + if max_tokens is not None and max_tokens > 0: + return ai_generate_chat( + system=system, + user=user, + temperature=temperature, + images_b64=images_b64, + max_tokens=int(max_tokens), + max_continuations=max_continuations, + ) + prompt = f"{system.strip()}\n\n---\n\n{user.strip()}" + return ai_generate( + prompt, + temperature=temperature, + images_b64=images_b64, + ) diff --git a/manual_trading_hub/hub_ai/context.py b/manual_trading_hub/hub_ai/context.py index ce712d5..35a02e5 100644 --- a/manual_trading_hub/hub_ai/context.py +++ b/manual_trading_hub/hub_ai/context.py @@ -1,1070 +1,1070 @@ -"""中控 AI:四户数据聚合为结构化上下文。""" -from __future__ import annotations - -import hashlib -import json -import os -import re -import time -from concurrent.futures import ThreadPoolExecutor -from datetime import datetime, timedelta -from threading import Lock -from typing import Any, Optional - -import httpx - -from hub_ai.config import ( - CHAT_CONTEXT_MAX_CHARS, - FUND_HISTORY_DAYS, - hub_agent_timeout, - hub_flask_timeout, - trading_day_reset_hour, -) -from hub_ai.fund_history import format_fund_history_text, get_fund_history, record_fund_snapshot -from hub_trades_lib import current_trading_day, summarize_trades - -_CHAT_CONTEXT_CACHE: dict[str, dict[str, Any]] = {} -_CHAT_CONTEXT_CACHE_LOCK = Lock() -_HUB_TPSL_MERGE_FN: Any = None - - -def _chat_context_cache_ttl_sec() -> float: - try: - return float(os.getenv("CHAT_CONTEXT_CACHE_TTL_SEC", "45") or "45") - except ValueError: - return 45.0 - - -def _hub_token() -> str: - return (os.getenv("HUB_BRIDGE_TOKEN") or os.getenv("CONTROL_TOKEN") or "").strip() - - -def _hub_headers() -> dict[str, str]: - tok = _hub_token() - return {"X-Hub-Token": tok} if tok else {} - - -def _agent_headers() -> dict[str, str]: - tok = (os.getenv("CONTROL_TOKEN") or os.getenv("HUB_BRIDGE_TOKEN") or "").strip() - return {"X-Control-Token": tok} if tok else {} - - -def _safe_float(v: Any) -> Optional[float]: - try: - if v is None or v == "": - return None - return float(v) - except (TypeError, ValueError): - return None - - -def _position_contracts(p: dict) -> float: - for key in ("contracts", "contracts_signed", "size"): - v = p.get(key) - try: - if v is not None and v != "": - return float(v) - except (TypeError, ValueError): - continue - return 0.0 - - -def _filter_open_positions(positions: list) -> list[dict]: - out: list[dict] = [] - for p in positions or []: - if not isinstance(p, dict): - continue - if abs(_position_contracts(p)) < 1e-12: - continue - out.append(p) - return out - - -def _account_open_position_count(ac: dict) -> int: - return len(_filter_open_positions(ac.get("positions") or [])) - - -def _monitor_counts(ac: dict) -> dict[str, int]: - mon = ac.get("monitor_lines") or {} - return { - "trends": len(mon.get("trends") or []), - "rolls": len(mon.get("rolls") or []), - "keys": len(mon.get("keys") or []), - "orders": len(mon.get("orders") or []), - } - - -def _position_float_pnl(pos: dict) -> float: - for key in ("unrealized_pnl", "unrealizedPnl", "upnl"): - v = _safe_float(pos.get(key)) - if v is not None: - return v - return 0.0 - - -def _collect_open_issues( - *, - monitored: bool, - agent_ok: bool, - flask_ok: bool, - positions: list, - hub_mon: Optional[dict], - day_pnl: float, -) -> list[str]: - issues: list[str] = [] - if not monitored: - return issues - if not agent_ok: - issues.append("Agent 连接异常") - if not flask_ok: - issues.append("Flask 监控连接异常") - if day_pnl < -0.01: - issues.append(f"当日平仓亏损 {day_pnl:.2f}U") - open_positions = _filter_open_positions(positions) - float_pnl = sum(_position_float_pnl(p) for p in open_positions) - if float_pnl < -0.5: - issues.append(f"当前浮亏 {float_pnl:.2f}U") - if isinstance(hub_mon, dict) and hub_mon.get("ok") is not False: - orders = hub_mon.get("orders") or [] - trends = hub_mon.get("trends") or [] - if open_positions and not orders and not trends: - issues.append("交易所有持仓但无本地 active 监控/趋势计划") - return issues - - -def previous_trading_day(trading_day: str) -> str: - day = (trading_day or "").strip()[:10] - if not day: - return day - dt = datetime.strptime(day, "%Y-%m-%d") - return (dt - timedelta(days=1)).strftime("%Y-%m-%d") - - -def _fmt_fund(v: Any) -> str: - n = _safe_float(v) - if n is None: - return "未知" - return f"{n:.2f}U" - - -def _format_trade_line(t: dict, *, day_label: str = "") -> str: - prefix = f"[{day_label}] " if day_label else "" - return ( - f"{prefix}{t.get('symbol')} {t.get('direction')} {t.get('result')} " - f"{t.get('pnl_amount')}U @ {t.get('closed_at') or '?'}" - ) - - -def _monitor_label(item: dict, default: str = "") -> str: - for key in ("monitor_type_label", "monitor_type", "entry_reason", "source_label"): - val = item.get(key) - if val: - return str(val) - return default - - -def _format_monitor_sections(hub_mon: Optional[dict]) -> dict[str, list[str]]: - out = {"trends": [], "orders": [], "keys": [], "rolls": []} - if not isinstance(hub_mon, dict) or hub_mon.get("ok") is False: - return out - for t in hub_mon.get("trends") or []: - if not isinstance(t, dict): - continue - out["trends"].append( - f"{t.get('symbol')} {t.get('direction')} " - f"SL={t.get('stop_loss')} TP={t.get('take_profit')} " - f"补仓区[{t.get('add_lower')}~{t.get('add_upper')}] " - f"状态={t.get('status')}" - ) - for o in hub_mon.get("orders") or []: - if not isinstance(o, dict): - continue - label = _monitor_label(o, "下单监控") - out["orders"].append( - f"{label}: {o.get('symbol')} {o.get('direction')} " - f"触发={o.get('trigger_price')} SL={o.get('stop_loss')} TP={o.get('take_profit')} " - f"状态={o.get('status')}" - ) - for k in hub_mon.get("keys") or []: - if not isinstance(k, dict): - continue - out["keys"].append( - f"关键位: {k.get('symbol')} {k.get('direction')} " - f"上={k.get('upper')} 下={k.get('lower')} 类型={k.get('monitor_type')}" - ) - for r in hub_mon.get("rolls") or []: - if not isinstance(r, dict): - continue - out["rolls"].append( - f"顺势加仓: {r.get('symbol')} {r.get('direction')} " - f"腿数={r.get('leg_count')} SL={r.get('current_stop_loss') or r.get('initial_stop_loss')} " - f"状态={r.get('status')}" - ) - return out - - -_SL_TP_COMBO_RE = re.compile(r"SL=([\d.eE+-]+).*TP=([\d.eE+-]+)", re.I) - - -def _norm_symbol(sym: str) -> str: - s = (sym or "").strip().upper() - if "/" in s: - s = s.split(":")[0].split("/")[0] - return s - - -def _symbols_match(a: str, b: str) -> bool: - na, nb = _norm_symbol(a), _norm_symbol(b) - return bool(na and nb and na == nb) - - -def _pick_tpsl_from_cond(cond: list) -> tuple[Optional[float], Optional[float]]: - sl = tp = None - if not cond: - return sl, tp - sl_o = tp_o = combo = None - for o in cond: - if not isinstance(o, dict): - continue - lbl = str(o.get("label") or "") - if "止盈止损" in lbl: - combo = o - elif lbl.startswith("止损"): - sl_o = o - elif lbl.startswith("止盈"): - tp_o = o - if combo: - lbl = str(combo.get("label") or "") - m = _SL_TP_COMBO_RE.search(lbl) - if m: - sl = _safe_float(m.group(1)) - tp = _safe_float(m.group(2)) - if sl_o and sl is None: - sl = _safe_float(sl_o.get("trigger_price")) - if tp_o and tp is None: - tp = _safe_float(tp_o.get("trigger_price")) - if sl is None: - for o in cond: - if not isinstance(o, dict): - continue - lbl = str(o.get("label") or "") - if "止损" in lbl and "止盈止损" not in lbl: - sl = _safe_float(o.get("trigger_price")) - if sl is not None: - break - if tp is None: - for o in cond: - if not isinstance(o, dict): - continue - lbl = str(o.get("label") or "") - if lbl.startswith("止盈") or ("止盈" in lbl and "止盈止损" not in lbl): - tp = _safe_float(o.get("trigger_price")) - if tp is not None: - break - return sl, tp - - -def _pick_tpsl_from_exchange_tpsl(et: Any) -> tuple[Optional[float], Optional[float]]: - if not isinstance(et, dict): - return None, None - sl = tp = None - slot_sl = et.get("sl") - slot_tp = et.get("tp") - if isinstance(slot_sl, dict): - sl = _safe_float(slot_sl.get("trigger_price")) - if isinstance(slot_tp, dict): - tp = _safe_float(slot_tp.get("trigger_price")) - return sl, tp - - -def _find_plan_tpsl_for_position( - symbol: str, - side: str, - hub_mon: Optional[dict], -) -> tuple[Optional[float], Optional[float], bool]: - """匹配本地监控/趋势计划:sl, tp, tp_is_program_monitored。""" - if not isinstance(hub_mon, dict): - return None, None, False - side_l = (side or "").lower() - for o in hub_mon.get("orders") or []: - if not isinstance(o, dict): - continue - o_sym = o.get("exchange_symbol") or o.get("symbol") or "" - if not _symbols_match(symbol, o_sym): - continue - if (o.get("direction") or "").lower() != side_l: - continue - return ( - _safe_float(o.get("stop_loss")), - _safe_float(o.get("take_profit")), - False, - ) - for t in hub_mon.get("trends") or []: - if not isinstance(t, dict): - continue - if not _symbols_match(symbol, t.get("symbol") or ""): - continue - if (t.get("direction") or "").lower() != side_l: - continue - plan_tp = t.get("take_profit") - tp = _safe_float(plan_tp) if plan_tp not in (None, "") else None - return _safe_float(t.get("stop_loss")), tp, tp is None - return None, None, False - - -def _resolve_position_tpsl(pos: dict, hub_mon: Optional[dict]) -> dict[str, Any]: - cond = pos.get("conditional_orders") or [] - cond_sl, cond_tp = _pick_tpsl_from_cond(cond) - et_sl, et_tp = _pick_tpsl_from_exchange_tpsl(pos.get("exchange_tpsl")) - plan_sl, plan_tp, tp_monitored = _find_plan_tpsl_for_position( - str(pos.get("symbol") or ""), - str(pos.get("side") or ""), - hub_mon, - ) - sl = cond_sl if cond_sl is not None else et_sl if et_sl is not None else plan_sl - tp_note = "" - tp: Optional[float] = None - if tp_monitored and cond_tp is None and et_tp is None: - tp_note = "程序监控" - else: - tp = cond_tp if cond_tp is not None else et_tp if et_tp is not None else plan_tp - if sl is not None and tp is not None and sl == tp: - tp = None - return {"sl": sl, "tp": tp, "tp_note": tp_note} - - -def _format_position_detail_line(pos: dict, hub_mon: Optional[dict]) -> str: - sym = pos.get("symbol") or "?" - side = pos.get("side") or "?" - contracts = pos.get("contracts") or pos.get("size") or "?" - upnl = _position_float_pnl(pos) - entry = _safe_float(pos.get("entry_price")) - tpsl = _resolve_position_tpsl(pos, hub_mon) - parts = [f"{sym} {side} 张数{contracts}"] - if entry is not None: - parts.append(f"入场{entry:g}") - if tpsl["sl"] is not None: - parts.append(f"止损{tpsl['sl']:g}") - else: - parts.append("止损=未检测到") - if tpsl["tp_note"]: - parts.append(f"止盈={tpsl['tp_note']}") - elif tpsl["tp"] is not None: - parts.append(f"止盈{tpsl['tp']:g}") - else: - parts.append("止盈=未检测到") - parts.append(f"浮盈亏{upnl:.4f}U") - return " - " + " ".join(parts) - - -def _enrich_positions_exchange_tpsl( - positions: list, - price_snap: Optional[dict], - hub_mon: Optional[dict], -) -> None: - global _HUB_TPSL_MERGE_FN - if not positions: - return - if _HUB_TPSL_MERGE_FN is None: - try: - from hub import _merge_flask_exchange_tpsl - - _HUB_TPSL_MERGE_FN = _merge_flask_exchange_tpsl - except Exception: - _HUB_TPSL_MERGE_FN = False - if not _HUB_TPSL_MERGE_FN: - return - try: - _HUB_TPSL_MERGE_FN( - {"agent": {"positions": positions}}, - price_snap if isinstance(price_snap, dict) else None, - hub_mon if isinstance(hub_mon, dict) else None, - ) - except Exception: - pass - - -def _fetch_account_bundle( - client: httpx.Client, - ex: dict, - trading_day: str, - *, - for_chat: bool = False, -) -> dict[str, Any]: - name = ex.get("name") or ex.get("key") or ex.get("id") - key = ex.get("key") or "" - enabled = bool(ex.get("enabled")) - env_disabled = bool(ex.get("env_disabled")) - monitored = enabled and not env_disabled - - base: dict[str, Any] = { - "id": ex.get("id"), - "key": key, - "name": name, - "enabled": enabled, - "env_disabled": env_disabled, - "status": "未监控" if not monitored else "已监控", - "trades": [], - "trade_stats": summarize_trades([]), - "positions": [], - "open_position_count": 0, - "float_pnl_u": 0.0, - "balance_usdt": None, - "funding_usdt": None, - "trading_usdt": None, - "available_trading_usdt": None, - "trades_yesterday": [], - "trade_stats_yesterday": summarize_trades([]), - "monitor_lines": {"trends": [], "orders": [], "keys": [], "rolls": []}, - "issues": [], - "agent_ok": False, - "flask_ok": False, - "hub_monitor": None, - "active_orders": 0, - "active_trends": 0, - } - if not monitored: - base["issues"] = [] - return base - - agent_url = (ex.get("agent_url") or "").rstrip("/") - flask_url = (ex.get("flask_url") or "").rstrip("/") - agent_body = None - if agent_url: - try: - r = client.get( - f"{agent_url}/status", - headers=_agent_headers(), - timeout=hub_agent_timeout(), - ) - if r.status_code == 200: - agent_body = r.json() - base["agent_ok"] = True - except Exception as exc: - base["issues"].append(f"Agent: {exc}") - - if isinstance(agent_body, dict): - base["balance_usdt"] = _safe_float(agent_body.get("balance_usdt")) - positions = agent_body.get("positions") or [] - if isinstance(positions, list): - open_positions = _filter_open_positions(positions) - base["positions"] = open_positions - base["open_position_count"] = len(open_positions) - base["float_pnl_u"] = round(sum(_position_float_pnl(p) for p in open_positions), 4) - - hub_mon = None - price_snap = None - prev_day = previous_trading_day(trading_day) - if flask_url: - try: - r = client.get( - f"{flask_url}/api/hub/account", - headers=_hub_headers(), - timeout=hub_flask_timeout(), - ) - if r.status_code == 200: - acct_body = r.json() - if isinstance(acct_body, dict) and acct_body.get("ok"): - base["funding_usdt"] = _safe_float(acct_body.get("funding_usdt")) - base["trading_usdt"] = _safe_float(acct_body.get("trading_usdt")) - base["available_trading_usdt"] = _safe_float(acct_body.get("available_trading_usdt")) - base["flask_ok"] = True - except Exception as exc: - base["issues"].append(f"资金接口: {exc}") - - try: - r = client.get( - f"{flask_url}/api/hub/trades/today", - headers=_hub_headers(), - params={"trading_day": trading_day}, - timeout=hub_flask_timeout(), - ) - if r.status_code == 200: - trades_body = r.json() - if isinstance(trades_body, dict) and trades_body.get("ok"): - base["trades"] = trades_body.get("trades") or [] - base["trade_stats"] = trades_body.get("stats") or summarize_trades(base["trades"]) - base["flask_ok"] = True - except Exception as exc: - base["issues"].append(f"成交接口: {exc}") - - if prev_day and not for_chat: - try: - r = client.get( - f"{flask_url}/api/hub/trades/today", - headers=_hub_headers(), - params={"trading_day": prev_day}, - timeout=hub_flask_timeout(), - ) - if r.status_code == 200: - y_body = r.json() - if isinstance(y_body, dict) and y_body.get("ok"): - base["trades_yesterday"] = y_body.get("trades") or [] - base["trade_stats_yesterday"] = y_body.get("stats") or summarize_trades( - base["trades_yesterday"] - ) - base["flask_ok"] = True - except Exception as exc: - base["issues"].append(f"昨日成交: {exc}") - - try: - r = client.get( - f"{flask_url}/api/hub/monitor", - headers=_hub_headers(), - timeout=hub_flask_timeout(), - ) - if r.status_code == 200: - hub_mon = r.json() - if isinstance(hub_mon, dict) and hub_mon.get("ok") is not False: - base["hub_monitor"] = hub_mon - base["flask_ok"] = True - base["active_orders"] = len(hub_mon.get("orders") or []) - base["active_trends"] = len(hub_mon.get("trends") or []) - base["monitor_lines"] = _format_monitor_sections(hub_mon) - except Exception as exc: - if "成交接口" not in str(base["issues"]): - base["issues"].append(f"监控接口: {exc}") - - try: - r = client.get( - f"{flask_url}/api/price_snapshot", - headers=_hub_headers(), - timeout=hub_flask_timeout(), - ) - if r.status_code == 200: - body = r.json() - if isinstance(body, dict): - price_snap = body - base["flask_ok"] = True - except Exception: - pass - - if base["positions"]: - _enrich_positions_exchange_tpsl(base["positions"], price_snap, hub_mon) - - if monitored and not base["agent_ok"] and not base["flask_ok"]: - base["status"] = "连接异常" - elif base["issues"]: - base["status"] = "已监控·需关注" - - day_pnl = float((base.get("trade_stats") or {}).get("total_pnl_u") or 0) - base["issues"].extend( - _collect_open_issues( - monitored=monitored, - agent_ok=base["agent_ok"], - flask_ok=base["flask_ok"], - positions=base["positions"], - hub_mon=hub_mon if isinstance(hub_mon, dict) else None, - day_pnl=day_pnl, - ) - ) - base["issues"] = list(dict.fromkeys(base["issues"])) - return base - - -def _fetch_account_bundle_isolated(ex: dict, trading_day: str, *, for_chat: bool) -> dict[str, Any]: - with httpx.Client() as client: - return _fetch_account_bundle(client, ex, trading_day, for_chat=for_chat) - - -def build_daily_context( - exchanges: list[dict], - *, - trading_day: Optional[str] = None, - for_chat: bool = False, -) -> dict[str, Any]: - day = (trading_day or "").strip()[:10] or current_trading_day( - reset_hour=trading_day_reset_hour() - ) - ex_list = exchanges or [] - if for_chat and len(ex_list) > 1: - workers = min(4, len(ex_list)) - with ThreadPoolExecutor(max_workers=workers) as pool: - accounts = list( - pool.map( - lambda ex: _fetch_account_bundle_isolated(ex, day, for_chat=True), - ex_list, - ) - ) - else: - with httpx.Client() as client: - accounts = [ - _fetch_account_bundle(client, ex, day, for_chat=for_chat) for ex in ex_list - ] - - total_closed_pnl = 0.0 - total_closed = total_win = total_loss = 0 - total_float = 0.0 - total_funding = 0.0 - total_trading = 0.0 - total_open_positions = 0 - funding_known = trading_known = 0 - for ac in accounts: - if ac.get("status") == "未监控": - continue - st = ac.get("trade_stats") or {} - total_closed_pnl += float(st.get("total_pnl_u") or 0) - total_closed += int(st.get("closed_count") or 0) - total_win += int(st.get("win_count") or 0) - total_loss += int(st.get("loss_count") or 0) - total_float += float(ac.get("float_pnl_u") or 0) - total_open_positions += int(ac.get("open_position_count") or _account_open_position_count(ac)) - fu = _safe_float(ac.get("funding_usdt")) - tu = _safe_float(ac.get("trading_usdt")) - if fu is not None: - total_funding += fu - funding_known += 1 - if tu is not None: - total_trading += tu - trading_known += 1 - if not funding_known: - total_funding = None - if not trading_known: - total_trading = None - - totals = { - "trading_day": day, - "prev_trading_day": previous_trading_day(day), - "total_pnl_u": round(total_closed_pnl, 4), - "closed_count": total_closed, - "win_count": total_win, - "loss_count": total_loss, - "float_pnl_u": round(total_float, 4), - "open_position_count": total_open_positions, - "total_funding_usdt": round(total_funding, 4) if total_funding is not None else None, - "total_trading_usdt": round(total_trading, 4) if total_trading is not None else None, - } - if for_chat: - fund_history: list = [] - fund_history_text = "" - else: - snap_accounts = [ - { - **ac, - "monitored": ac.get("status") != "未监控", - } - for ac in accounts - ] - record_fund_snapshot(day, snap_accounts, keep_days=FUND_HISTORY_DAYS) - fund_history = get_fund_history(anchor_day=day, keep_days=FUND_HISTORY_DAYS) - account_names = {str(ac.get("key") or ac.get("id")): ac.get("name") for ac in accounts} - fund_history_text = format_fund_history_text(fund_history, account_names=account_names) - payload = { - "trading_day": day, - "prev_trading_day": previous_trading_day(day), - "totals": totals, - "accounts": accounts, - "fund_history": fund_history, - "fund_history_text": fund_history_text, - } - if for_chat: - text = format_chat_context_for_chat(payload) - else: - text = format_context_text(payload) - digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:16] - return { - "trading_day": day, - "prev_trading_day": previous_trading_day(day), - "totals": totals, - "accounts": accounts, - "fund_history": fund_history, - "fund_history_text": fund_history_text, - "text": text, - "context_hash": digest, - } - - -def build_chat_context( - exchanges: list[dict], - *, - trading_day: Optional[str] = None, - force_refresh: bool = False, -) -> dict[str, Any]: - """聊天专用上下文:并行拉取、跳过资金曲线/昨日成交,短 TTL 缓存。""" - day = (trading_day or "").strip()[:10] or current_trading_day( - reset_hour=trading_day_reset_hour() - ) - ttl = _chat_context_cache_ttl_sec() - now = time.monotonic() - if not force_refresh and ttl > 0: - with _CHAT_CONTEXT_CACHE_LOCK: - hit = _CHAT_CONTEXT_CACHE.get(day) - if hit and (now - float(hit.get("ts") or 0)) < ttl: - return hit["ctx"] - ctx = build_daily_context(exchanges, trading_day=day, for_chat=True) - if ttl > 0: - with _CHAT_CONTEXT_CACHE_LOCK: - _CHAT_CONTEXT_CACHE[day] = {"ts": now, "ctx": ctx} - return ctx - - -def format_context_text(payload: dict) -> str: - lines = [] - totals = payload.get("totals") or {} - day = totals.get("trading_day") - prev_day = totals.get("prev_trading_day") or previous_trading_day(str(day or "")) - lines.append( - f"【合计·今日 {day}】平仓盈亏 {totals.get('total_pnl_u')}U | " - f"笔数 {totals.get('closed_count')}(胜{totals.get('win_count')}/负{totals.get('loss_count')})| " - f"实盘持仓 {totals.get('open_position_count', 0)} 仓 | " - f"浮盈亏 {totals.get('float_pnl_u')}U | " - f"资金账户合计 {_fmt_fund(totals.get('total_funding_usdt'))} | " - f"交易账户合计 {_fmt_fund(totals.get('total_trading_usdt'))}" - ) - lines.append( - f"【对比交易日】昨日={prev_day},今日={day}。" - "「持仓」= 交易所 Agent 实盘;「趋势/关键位/监控单/加仓」= 本地计划,不等于已开仓。" - ) - fund_txt = str(payload.get("fund_history_text") or "").strip() - if fund_txt: - lines.append("") - lines.append(fund_txt) - lines.append("") - for ac in payload.get("accounts") or []: - st = ac.get("trade_stats") or {} - sty = ac.get("trade_stats_yesterday") or {} - lines.append(f"--- 账户:{ac.get('name')} ({ac.get('key')}) ---") - lines.append(f"状态:{ac.get('status')}") - if ac.get("status") == "未监控": - lines.append("") - continue - lines.append( - f"资金账户 {_fmt_fund(ac.get('funding_usdt'))} | " - f"交易账户 {_fmt_fund(ac.get('trading_usdt'))} | " - f"可用 {_fmt_fund(ac.get('available_trading_usdt'))}" - ) - lines.append( - f"今日({day})平仓:{st.get('closed_count')} 笔,盈亏 {st.get('total_pnl_u')}U " - f"(胜{st.get('win_count')}/负{st.get('loss_count')})" - ) - lines.append( - f"昨日({prev_day})平仓:{sty.get('closed_count')} 笔,盈亏 {sty.get('total_pnl_u')}U " - f"(胜{sty.get('win_count')}/负{sty.get('loss_count')})" - ) - open_n = int(ac.get("open_position_count") or _account_open_position_count(ac)) - if open_n <= 0: - lines.append("当前交易所持仓:无(空仓)") - else: - lines.append( - f"当前交易所持仓:{open_n} 仓 | 浮盈亏合计 {ac.get('float_pnl_u')}U" - ) - mon = ac.get("monitor_lines") or {} - if mon.get("trends"): - lines.append("趋势回调计划(本地,非持仓):") - for row in mon["trends"][:8]: - lines.append(f" - {row}") - if mon.get("rolls"): - lines.append("顺势加仓(本地,非持仓):") - for row in mon["rolls"][:8]: - lines.append(f" - {row}") - if mon.get("keys"): - lines.append("关键位监控(本地,非持仓):") - for row in mon["keys"][:8]: - lines.append(f" - {row}") - if mon.get("orders"): - lines.append("进行中的下单监控(本地,非持仓):") - for row in mon["orders"][:8]: - lines.append(f" - {row}") - positions = ac.get("positions") or [] - hub_mon = ac.get("hub_monitor") - if positions: - lines.append("持仓明细(交易所实盘,含止盈止损若已挂):") - for p in positions[:8]: - if not isinstance(p, dict): - continue - lines.append(_format_position_detail_line(p, hub_mon)) - lines.append( - f"Agent合约余额:{ac.get('balance_usdt') if ac.get('balance_usdt') is not None else '未知'} USDT" - ) - trades_today = ac.get("trades") or [] - if trades_today: - lines.append(f"今日平仓明细:") - for t in trades_today[:15]: - lines.append(f" - {_format_trade_line(t)}") - trades_y = ac.get("trades_yesterday") or [] - if trades_y: - lines.append(f"昨日平仓明细:") - for t in trades_y[:15]: - lines.append(f" - {_format_trade_line(t)}") - if not trades_today and not trades_y: - lines.append("平仓明细:无") - issues = ac.get("issues") or [] - if issues: - lines.append("关注点:" + ";".join(issues)) - lines.append("") - return "\n".join(lines).strip() - - -def format_summary_context_text(payload: dict) -> str: - """今日总结专用:仅当日平仓/持仓/监控,不含昨日明细与资金走势。""" - lines = [] - totals = payload.get("totals") or {} - day = totals.get("trading_day") - lines.append( - f"【合计·今日 {day}】平仓盈亏 {totals.get('total_pnl_u')}U | " - f"笔数 {totals.get('closed_count')}(胜{totals.get('win_count')}/负{totals.get('loss_count')})| " - f"实盘持仓 {totals.get('open_position_count', 0)} 仓 | " - f"浮盈亏 {totals.get('float_pnl_u')}U | " - f"资金账户合计 {_fmt_fund(totals.get('total_funding_usdt'))} | " - f"交易账户合计 {_fmt_fund(totals.get('total_trading_usdt'))}" - ) - lines.append( - f"【说明】交易日={day}。" - "「持仓」= 交易所 Agent 实盘;「趋势/关键位/监控单/加仓」= 本地计划,不等于已开仓。" - ) - lines.append("") - for ac in payload.get("accounts") or []: - st = ac.get("trade_stats") or {} - lines.append(f"--- 账户:{ac.get('name')} ({ac.get('key')}) ---") - lines.append(f"状态:{ac.get('status')}") - if ac.get("status") == "未监控": - lines.append("") - continue - lines.append( - f"资金账户 {_fmt_fund(ac.get('funding_usdt'))} | " - f"交易账户 {_fmt_fund(ac.get('trading_usdt'))} | " - f"可用 {_fmt_fund(ac.get('available_trading_usdt'))}" - ) - lines.append( - f"今日({day})平仓:{st.get('closed_count')} 笔,盈亏 {st.get('total_pnl_u')}U " - f"(胜{st.get('win_count')}/负{st.get('loss_count')})" - ) - open_n = int(ac.get("open_position_count") or _account_open_position_count(ac)) - if open_n <= 0: - lines.append("当前交易所持仓:无(空仓)") - else: - lines.append( - f"当前交易所持仓:{open_n} 仓 | 浮盈亏合计 {ac.get('float_pnl_u')}U" - ) - mon = ac.get("monitor_lines") or {} - if mon.get("trends"): - lines.append("趋势回调计划(本地,非持仓):") - for row in mon["trends"][:8]: - lines.append(f" - {row}") - if mon.get("rolls"): - lines.append("顺势加仓(本地,非持仓):") - for row in mon["rolls"][:8]: - lines.append(f" - {row}") - if mon.get("keys"): - lines.append("关键位监控(本地,非持仓):") - for row in mon["keys"][:8]: - lines.append(f" - {row}") - if mon.get("orders"): - lines.append("进行中的下单监控(本地,非持仓):") - for row in mon["orders"][:8]: - lines.append(f" - {row}") - positions = ac.get("positions") or [] - hub_mon = ac.get("hub_monitor") - if positions: - lines.append("持仓明细(交易所实盘,含止盈止损若已挂):") - for p in positions[:8]: - if not isinstance(p, dict): - continue - lines.append(_format_position_detail_line(p, hub_mon)) - lines.append( - f"Agent合约余额:{ac.get('balance_usdt') if ac.get('balance_usdt') is not None else '未知'} USDT" - ) - trades_today = ac.get("trades") or [] - if trades_today: - lines.append("今日平仓明细:") - for t in trades_today[:15]: - lines.append(f" - {_format_trade_line(t)}") - else: - lines.append("今日平仓明细:无") - issues = ac.get("issues") or [] - if issues: - lines.append("关注点:" + ";".join(issues)) - lines.append("") - return "\n".join(lines).strip() - - -def summary_context_hash(payload: dict) -> str: - text = format_summary_context_text(payload) - return hashlib.sha256(text.encode("utf-8")).hexdigest()[:16] - - -def format_account_remark(ac: dict) -> str: - """分户表格备注:监控摘要 + 持仓。""" - parts: list[str] = [] - mon = ac.get("monitor_lines") or {} - if mon.get("trends"): - parts.append(f"趋势{len(mon['trends'])}") - if mon.get("rolls"): - parts.append(f"加仓{len(mon['rolls'])}") - if mon.get("keys"): - parts.append(f"关键位{len(mon['keys'])}") - if mon.get("orders"): - parts.append(f"监控单{len(mon['orders'])}") - positions = ac.get("positions") or [] - if positions: - for p in positions[:2]: - if not isinstance(p, dict): - continue - sym = p.get("symbol") or "?" - side = p.get("side") or "?" - upnl = _position_float_pnl(p) - parts.append(f"{sym} {side} 浮{upnl:.2f}U") - if len(positions) > 2: - parts.append(f"+{len(positions) - 2}仓") - if not parts: - issues = ac.get("issues") or [] - if issues: - return ";".join(str(x) for x in issues[:2]) - return "无" - return ";".join(parts) - - -def format_dashboard_account_detail(ac: dict) -> dict[str, Any]: - """数据看板分户卡片:监控仅数量,持仓逐行(含浮盈亏)。""" - mon = ac.get("monitor_lines") or {} - position_lines: list[dict[str, Any]] = [] - for p in _filter_open_positions(ac.get("positions") or []): - sym = p.get("symbol") or "?" - side = p.get("side") or "?" - upnl = _position_float_pnl(p) - position_lines.append( - { - "kind": "position", - "text": f"{sym} {side}", - "pnl": round(upnl, 4), - } - ) - issues = [str(x) for x in (ac.get("issues") or [])[:3]] - return { - "monitor_counts": { - "keys": len(mon.get("keys") or []), - "orders": len(mon.get("orders") or []), - "trends": len(mon.get("trends") or []), - "rolls": len(mon.get("rolls") or []), - }, - "position_lines": position_lines, - "issues": issues, - } - - -def collect_closed_trades_snapshot( - accounts: list[dict], - *, - today: str, - yesterday: str | None = None, -) -> list[dict]: - rows: list[dict] = [] - for ac in accounts or []: - name = ac.get("name") or ac.get("key") - if yesterday: - for t in ac.get("trades_yesterday") or []: - if not isinstance(t, dict): - continue - rows.append({**t, "account_name": name, "trading_day": yesterday}) - for t in ac.get("trades") or []: - if not isinstance(t, dict): - continue - rows.append({**t, "account_name": name, "trading_day": today}) - rows.sort(key=lambda x: str(x.get("closed_at") or x.get("opened_at") or ""), reverse=True) - return rows[:80] - - -def format_chat_position_overview(payload: dict) -> str: - totals = payload.get("totals") or {} - total_open = int(totals.get("open_position_count") or 0) - if total_open <= 0: - head = f"【实盘持仓总览】当前空仓(监控户合计 0 仓)。浮盈亏 0U 表示无持仓,不是「有仓但不动」。" - else: - head = ( - f"【实盘持仓总览】监控户合计 {total_open} 仓," - f"浮盈亏合计 {totals.get('float_pnl_u')}U。" - ) - lines = [ - head, - "【区分】只有带「持仓明细/交易所实盘」字样的才是已开仓;趋势回调、关键位、下单监控、顺势加仓是本地计划/监控,不算持仓。持仓明细若含止损/止盈价,表示已挂条件单或监控计划中有价位。", - ] - for ac in payload.get("accounts") or []: - if ac.get("status") == "未监控": - continue - n = int(ac.get("open_position_count") or _account_open_position_count(ac)) - mc = _monitor_counts(ac) - mon_parts = [] - if mc["trends"]: - mon_parts.append(f"趋势{mc['trends']}") - if mc["rolls"]: - mon_parts.append(f"加仓{mc['rolls']}") - if mc["keys"]: - mon_parts.append(f"关键位{mc['keys']}") - if mc["orders"]: - mon_parts.append(f"监控单{mc['orders']}") - mon_txt = f";本地监控 {' '.join(mon_parts)}" if mon_parts else "" - if n <= 0: - lines.append(f"- {ac.get('name')}:空仓{mon_txt}") - else: - lines.append( - f"- {ac.get('name')}:{n}仓 浮盈亏{ac.get('float_pnl_u')}U{mon_txt}" - ) - return "\n".join(lines) - - -def format_chat_context_slim(payload: dict) -> str: - """聊天专用:不含 180 日资金曲线与昨日平仓明细,避免挤占对话上下文。""" - totals = payload.get("totals") or {} - day = totals.get("trading_day") - lines = [ - f"【今日合计 {day}】平仓盈亏 {totals.get('total_pnl_u')}U | " - f"笔数 {totals.get('closed_count')}(胜{totals.get('win_count')}/负{totals.get('loss_count')})| " - f"实盘持仓 {totals.get('open_position_count', 0)} 仓 | 浮盈亏 {totals.get('float_pnl_u')}U", - "【说明】持仓=交易所实盘;趋势/关键位/监控单=本地计划,不等于已开仓。持仓行内「止损/止盈」= 交易所条件单或监控计划价(与监控页一致)。", - ] - for ac in payload.get("accounts") or []: - if ac.get("status") == "未监控": - lines.append(f"- {ac.get('name')}:未监控") - continue - st = ac.get("trade_stats") or {} - open_n = int(ac.get("open_position_count") or _account_open_position_count(ac)) - pos_txt = "空仓" if open_n <= 0 else f"{open_n}仓 浮盈亏{ac.get('float_pnl_u')}U" - mc = _monitor_counts(ac) - mon = [] - if mc["trends"]: - mon.append(f"趋势{mc['trends']}") - if mc["rolls"]: - mon.append(f"加仓{mc['rolls']}") - if mc["keys"]: - mon.append(f"关键位{mc['keys']}") - if mc["orders"]: - mon.append(f"监控单{mc['orders']}") - mon_txt = f";监控 {'/'.join(mon)}" if mon else "" - lines.append( - f"- {ac.get('name')}:{pos_txt} | 今日盈亏{st.get('total_pnl_u')}U " - f"({st.get('closed_count')}笔) | 资金{_fmt_fund(ac.get('funding_usdt'))} " - f"交易{_fmt_fund(ac.get('trading_usdt'))}{mon_txt}" - ) - trades = ac.get("trades") or [] - if trades: - for t in trades[:4]: - lines.append(f" · {_format_trade_line(t)}") - if len(trades) > 4: - lines.append(f" · …共{len(trades)}笔今日平仓") - positions = ac.get("positions") or [] - hub_mon = ac.get("hub_monitor") - for p in positions[:4]: - if not isinstance(p, dict): - continue - lines.append(f" · {_format_position_detail_line(p, hub_mon).lstrip(' - ')}") - return "\n".join(lines) - - -def format_chat_context_for_chat( - payload: dict, - max_chars: int = CHAT_CONTEXT_MAX_CHARS, -) -> str: - overview = format_chat_position_overview(payload) - body = format_chat_context_slim(payload) - text = overview + "\n\n" + body - if len(text) <= max_chars: - return text - budget = max(2000, max_chars - len(overview) - 4) - return overview + "\n\n" + body[:budget].rstrip() + "…" - - -def format_chat_context_brief( - payload: dict, - max_chars: int = CHAT_CONTEXT_MAX_CHARS, -) -> str: - return format_chat_context_for_chat(payload, max_chars=max_chars) +"""中控 AI:四户数据聚合为结构化上下文。""" +from __future__ import annotations + +import hashlib +import json +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timedelta +from threading import Lock +from typing import Any, Optional + +import httpx + +from hub_ai.config import ( + CHAT_CONTEXT_MAX_CHARS, + FUND_HISTORY_DAYS, + hub_agent_timeout, + hub_flask_timeout, + trading_day_reset_hour, +) +from hub_ai.fund_history import format_fund_history_text, get_fund_history, record_fund_snapshot +from lib.hub.hub_trades_lib import current_trading_day, summarize_trades + +_CHAT_CONTEXT_CACHE: dict[str, dict[str, Any]] = {} +_CHAT_CONTEXT_CACHE_LOCK = Lock() +_HUB_TPSL_MERGE_FN: Any = None + + +def _chat_context_cache_ttl_sec() -> float: + try: + return float(os.getenv("CHAT_CONTEXT_CACHE_TTL_SEC", "45") or "45") + except ValueError: + return 45.0 + + +def _hub_token() -> str: + return (os.getenv("HUB_BRIDGE_TOKEN") or os.getenv("CONTROL_TOKEN") or "").strip() + + +def _hub_headers() -> dict[str, str]: + tok = _hub_token() + return {"X-Hub-Token": tok} if tok else {} + + +def _agent_headers() -> dict[str, str]: + tok = (os.getenv("CONTROL_TOKEN") or os.getenv("HUB_BRIDGE_TOKEN") or "").strip() + return {"X-Control-Token": tok} if tok else {} + + +def _safe_float(v: Any) -> Optional[float]: + try: + if v is None or v == "": + return None + return float(v) + except (TypeError, ValueError): + return None + + +def _position_contracts(p: dict) -> float: + for key in ("contracts", "contracts_signed", "size"): + v = p.get(key) + try: + if v is not None and v != "": + return float(v) + except (TypeError, ValueError): + continue + return 0.0 + + +def _filter_open_positions(positions: list) -> list[dict]: + out: list[dict] = [] + for p in positions or []: + if not isinstance(p, dict): + continue + if abs(_position_contracts(p)) < 1e-12: + continue + out.append(p) + return out + + +def _account_open_position_count(ac: dict) -> int: + return len(_filter_open_positions(ac.get("positions") or [])) + + +def _monitor_counts(ac: dict) -> dict[str, int]: + mon = ac.get("monitor_lines") or {} + return { + "trends": len(mon.get("trends") or []), + "rolls": len(mon.get("rolls") or []), + "keys": len(mon.get("keys") or []), + "orders": len(mon.get("orders") or []), + } + + +def _position_float_pnl(pos: dict) -> float: + for key in ("unrealized_pnl", "unrealizedPnl", "upnl"): + v = _safe_float(pos.get(key)) + if v is not None: + return v + return 0.0 + + +def _collect_open_issues( + *, + monitored: bool, + agent_ok: bool, + flask_ok: bool, + positions: list, + hub_mon: Optional[dict], + day_pnl: float, +) -> list[str]: + issues: list[str] = [] + if not monitored: + return issues + if not agent_ok: + issues.append("Agent 连接异常") + if not flask_ok: + issues.append("Flask 监控连接异常") + if day_pnl < -0.01: + issues.append(f"当日平仓亏损 {day_pnl:.2f}U") + open_positions = _filter_open_positions(positions) + float_pnl = sum(_position_float_pnl(p) for p in open_positions) + if float_pnl < -0.5: + issues.append(f"当前浮亏 {float_pnl:.2f}U") + if isinstance(hub_mon, dict) and hub_mon.get("ok") is not False: + orders = hub_mon.get("orders") or [] + trends = hub_mon.get("trends") or [] + if open_positions and not orders and not trends: + issues.append("交易所有持仓但无本地 active 监控/趋势计划") + return issues + + +def previous_trading_day(trading_day: str) -> str: + day = (trading_day or "").strip()[:10] + if not day: + return day + dt = datetime.strptime(day, "%Y-%m-%d") + return (dt - timedelta(days=1)).strftime("%Y-%m-%d") + + +def _fmt_fund(v: Any) -> str: + n = _safe_float(v) + if n is None: + return "未知" + return f"{n:.2f}U" + + +def _format_trade_line(t: dict, *, day_label: str = "") -> str: + prefix = f"[{day_label}] " if day_label else "" + return ( + f"{prefix}{t.get('symbol')} {t.get('direction')} {t.get('result')} " + f"{t.get('pnl_amount')}U @ {t.get('closed_at') or '?'}" + ) + + +def _monitor_label(item: dict, default: str = "") -> str: + for key in ("monitor_type_label", "monitor_type", "entry_reason", "source_label"): + val = item.get(key) + if val: + return str(val) + return default + + +def _format_monitor_sections(hub_mon: Optional[dict]) -> dict[str, list[str]]: + out = {"trends": [], "orders": [], "keys": [], "rolls": []} + if not isinstance(hub_mon, dict) or hub_mon.get("ok") is False: + return out + for t in hub_mon.get("trends") or []: + if not isinstance(t, dict): + continue + out["trends"].append( + f"{t.get('symbol')} {t.get('direction')} " + f"SL={t.get('stop_loss')} TP={t.get('take_profit')} " + f"补仓区[{t.get('add_lower')}~{t.get('add_upper')}] " + f"状态={t.get('status')}" + ) + for o in hub_mon.get("orders") or []: + if not isinstance(o, dict): + continue + label = _monitor_label(o, "下单监控") + out["orders"].append( + f"{label}: {o.get('symbol')} {o.get('direction')} " + f"触发={o.get('trigger_price')} SL={o.get('stop_loss')} TP={o.get('take_profit')} " + f"状态={o.get('status')}" + ) + for k in hub_mon.get("keys") or []: + if not isinstance(k, dict): + continue + out["keys"].append( + f"关键位: {k.get('symbol')} {k.get('direction')} " + f"上={k.get('upper')} 下={k.get('lower')} 类型={k.get('monitor_type')}" + ) + for r in hub_mon.get("rolls") or []: + if not isinstance(r, dict): + continue + out["rolls"].append( + f"顺势加仓: {r.get('symbol')} {r.get('direction')} " + f"腿数={r.get('leg_count')} SL={r.get('current_stop_loss') or r.get('initial_stop_loss')} " + f"状态={r.get('status')}" + ) + return out + + +_SL_TP_COMBO_RE = re.compile(r"SL=([\d.eE+-]+).*TP=([\d.eE+-]+)", re.I) + + +def _norm_symbol(sym: str) -> str: + s = (sym or "").strip().upper() + if "/" in s: + s = s.split(":")[0].split("/")[0] + return s + + +def _symbols_match(a: str, b: str) -> bool: + na, nb = _norm_symbol(a), _norm_symbol(b) + return bool(na and nb and na == nb) + + +def _pick_tpsl_from_cond(cond: list) -> tuple[Optional[float], Optional[float]]: + sl = tp = None + if not cond: + return sl, tp + sl_o = tp_o = combo = None + for o in cond: + if not isinstance(o, dict): + continue + lbl = str(o.get("label") or "") + if "止盈止损" in lbl: + combo = o + elif lbl.startswith("止损"): + sl_o = o + elif lbl.startswith("止盈"): + tp_o = o + if combo: + lbl = str(combo.get("label") or "") + m = _SL_TP_COMBO_RE.search(lbl) + if m: + sl = _safe_float(m.group(1)) + tp = _safe_float(m.group(2)) + if sl_o and sl is None: + sl = _safe_float(sl_o.get("trigger_price")) + if tp_o and tp is None: + tp = _safe_float(tp_o.get("trigger_price")) + if sl is None: + for o in cond: + if not isinstance(o, dict): + continue + lbl = str(o.get("label") or "") + if "止损" in lbl and "止盈止损" not in lbl: + sl = _safe_float(o.get("trigger_price")) + if sl is not None: + break + if tp is None: + for o in cond: + if not isinstance(o, dict): + continue + lbl = str(o.get("label") or "") + if lbl.startswith("止盈") or ("止盈" in lbl and "止盈止损" not in lbl): + tp = _safe_float(o.get("trigger_price")) + if tp is not None: + break + return sl, tp + + +def _pick_tpsl_from_exchange_tpsl(et: Any) -> tuple[Optional[float], Optional[float]]: + if not isinstance(et, dict): + return None, None + sl = tp = None + slot_sl = et.get("sl") + slot_tp = et.get("tp") + if isinstance(slot_sl, dict): + sl = _safe_float(slot_sl.get("trigger_price")) + if isinstance(slot_tp, dict): + tp = _safe_float(slot_tp.get("trigger_price")) + return sl, tp + + +def _find_plan_tpsl_for_position( + symbol: str, + side: str, + hub_mon: Optional[dict], +) -> tuple[Optional[float], Optional[float], bool]: + """匹配本地监控/趋势计划:sl, tp, tp_is_program_monitored。""" + if not isinstance(hub_mon, dict): + return None, None, False + side_l = (side or "").lower() + for o in hub_mon.get("orders") or []: + if not isinstance(o, dict): + continue + o_sym = o.get("exchange_symbol") or o.get("symbol") or "" + if not _symbols_match(symbol, o_sym): + continue + if (o.get("direction") or "").lower() != side_l: + continue + return ( + _safe_float(o.get("stop_loss")), + _safe_float(o.get("take_profit")), + False, + ) + for t in hub_mon.get("trends") or []: + if not isinstance(t, dict): + continue + if not _symbols_match(symbol, t.get("symbol") or ""): + continue + if (t.get("direction") or "").lower() != side_l: + continue + plan_tp = t.get("take_profit") + tp = _safe_float(plan_tp) if plan_tp not in (None, "") else None + return _safe_float(t.get("stop_loss")), tp, tp is None + return None, None, False + + +def _resolve_position_tpsl(pos: dict, hub_mon: Optional[dict]) -> dict[str, Any]: + cond = pos.get("conditional_orders") or [] + cond_sl, cond_tp = _pick_tpsl_from_cond(cond) + et_sl, et_tp = _pick_tpsl_from_exchange_tpsl(pos.get("exchange_tpsl")) + plan_sl, plan_tp, tp_monitored = _find_plan_tpsl_for_position( + str(pos.get("symbol") or ""), + str(pos.get("side") or ""), + hub_mon, + ) + sl = cond_sl if cond_sl is not None else et_sl if et_sl is not None else plan_sl + tp_note = "" + tp: Optional[float] = None + if tp_monitored and cond_tp is None and et_tp is None: + tp_note = "程序监控" + else: + tp = cond_tp if cond_tp is not None else et_tp if et_tp is not None else plan_tp + if sl is not None and tp is not None and sl == tp: + tp = None + return {"sl": sl, "tp": tp, "tp_note": tp_note} + + +def _format_position_detail_line(pos: dict, hub_mon: Optional[dict]) -> str: + sym = pos.get("symbol") or "?" + side = pos.get("side") or "?" + contracts = pos.get("contracts") or pos.get("size") or "?" + upnl = _position_float_pnl(pos) + entry = _safe_float(pos.get("entry_price")) + tpsl = _resolve_position_tpsl(pos, hub_mon) + parts = [f"{sym} {side} 张数{contracts}"] + if entry is not None: + parts.append(f"入场{entry:g}") + if tpsl["sl"] is not None: + parts.append(f"止损{tpsl['sl']:g}") + else: + parts.append("止损=未检测到") + if tpsl["tp_note"]: + parts.append(f"止盈={tpsl['tp_note']}") + elif tpsl["tp"] is not None: + parts.append(f"止盈{tpsl['tp']:g}") + else: + parts.append("止盈=未检测到") + parts.append(f"浮盈亏{upnl:.4f}U") + return " - " + " ".join(parts) + + +def _enrich_positions_exchange_tpsl( + positions: list, + price_snap: Optional[dict], + hub_mon: Optional[dict], +) -> None: + global _HUB_TPSL_MERGE_FN + if not positions: + return + if _HUB_TPSL_MERGE_FN is None: + try: + from hub import _merge_flask_exchange_tpsl + + _HUB_TPSL_MERGE_FN = _merge_flask_exchange_tpsl + except Exception: + _HUB_TPSL_MERGE_FN = False + if not _HUB_TPSL_MERGE_FN: + return + try: + _HUB_TPSL_MERGE_FN( + {"agent": {"positions": positions}}, + price_snap if isinstance(price_snap, dict) else None, + hub_mon if isinstance(hub_mon, dict) else None, + ) + except Exception: + pass + + +def _fetch_account_bundle( + client: httpx.Client, + ex: dict, + trading_day: str, + *, + for_chat: bool = False, +) -> dict[str, Any]: + name = ex.get("name") or ex.get("key") or ex.get("id") + key = ex.get("key") or "" + enabled = bool(ex.get("enabled")) + env_disabled = bool(ex.get("env_disabled")) + monitored = enabled and not env_disabled + + base: dict[str, Any] = { + "id": ex.get("id"), + "key": key, + "name": name, + "enabled": enabled, + "env_disabled": env_disabled, + "status": "未监控" if not monitored else "已监控", + "trades": [], + "trade_stats": summarize_trades([]), + "positions": [], + "open_position_count": 0, + "float_pnl_u": 0.0, + "balance_usdt": None, + "funding_usdt": None, + "trading_usdt": None, + "available_trading_usdt": None, + "trades_yesterday": [], + "trade_stats_yesterday": summarize_trades([]), + "monitor_lines": {"trends": [], "orders": [], "keys": [], "rolls": []}, + "issues": [], + "agent_ok": False, + "flask_ok": False, + "hub_monitor": None, + "active_orders": 0, + "active_trends": 0, + } + if not monitored: + base["issues"] = [] + return base + + agent_url = (ex.get("agent_url") or "").rstrip("/") + flask_url = (ex.get("flask_url") or "").rstrip("/") + agent_body = None + if agent_url: + try: + r = client.get( + f"{agent_url}/status", + headers=_agent_headers(), + timeout=hub_agent_timeout(), + ) + if r.status_code == 200: + agent_body = r.json() + base["agent_ok"] = True + except Exception as exc: + base["issues"].append(f"Agent: {exc}") + + if isinstance(agent_body, dict): + base["balance_usdt"] = _safe_float(agent_body.get("balance_usdt")) + positions = agent_body.get("positions") or [] + if isinstance(positions, list): + open_positions = _filter_open_positions(positions) + base["positions"] = open_positions + base["open_position_count"] = len(open_positions) + base["float_pnl_u"] = round(sum(_position_float_pnl(p) for p in open_positions), 4) + + hub_mon = None + price_snap = None + prev_day = previous_trading_day(trading_day) + if flask_url: + try: + r = client.get( + f"{flask_url}/api/hub/account", + headers=_hub_headers(), + timeout=hub_flask_timeout(), + ) + if r.status_code == 200: + acct_body = r.json() + if isinstance(acct_body, dict) and acct_body.get("ok"): + base["funding_usdt"] = _safe_float(acct_body.get("funding_usdt")) + base["trading_usdt"] = _safe_float(acct_body.get("trading_usdt")) + base["available_trading_usdt"] = _safe_float(acct_body.get("available_trading_usdt")) + base["flask_ok"] = True + except Exception as exc: + base["issues"].append(f"资金接口: {exc}") + + try: + r = client.get( + f"{flask_url}/api/hub/trades/today", + headers=_hub_headers(), + params={"trading_day": trading_day}, + timeout=hub_flask_timeout(), + ) + if r.status_code == 200: + trades_body = r.json() + if isinstance(trades_body, dict) and trades_body.get("ok"): + base["trades"] = trades_body.get("trades") or [] + base["trade_stats"] = trades_body.get("stats") or summarize_trades(base["trades"]) + base["flask_ok"] = True + except Exception as exc: + base["issues"].append(f"成交接口: {exc}") + + if prev_day and not for_chat: + try: + r = client.get( + f"{flask_url}/api/hub/trades/today", + headers=_hub_headers(), + params={"trading_day": prev_day}, + timeout=hub_flask_timeout(), + ) + if r.status_code == 200: + y_body = r.json() + if isinstance(y_body, dict) and y_body.get("ok"): + base["trades_yesterday"] = y_body.get("trades") or [] + base["trade_stats_yesterday"] = y_body.get("stats") or summarize_trades( + base["trades_yesterday"] + ) + base["flask_ok"] = True + except Exception as exc: + base["issues"].append(f"昨日成交: {exc}") + + try: + r = client.get( + f"{flask_url}/api/hub/monitor", + headers=_hub_headers(), + timeout=hub_flask_timeout(), + ) + if r.status_code == 200: + hub_mon = r.json() + if isinstance(hub_mon, dict) and hub_mon.get("ok") is not False: + base["hub_monitor"] = hub_mon + base["flask_ok"] = True + base["active_orders"] = len(hub_mon.get("orders") or []) + base["active_trends"] = len(hub_mon.get("trends") or []) + base["monitor_lines"] = _format_monitor_sections(hub_mon) + except Exception as exc: + if "成交接口" not in str(base["issues"]): + base["issues"].append(f"监控接口: {exc}") + + try: + r = client.get( + f"{flask_url}/api/price_snapshot", + headers=_hub_headers(), + timeout=hub_flask_timeout(), + ) + if r.status_code == 200: + body = r.json() + if isinstance(body, dict): + price_snap = body + base["flask_ok"] = True + except Exception: + pass + + if base["positions"]: + _enrich_positions_exchange_tpsl(base["positions"], price_snap, hub_mon) + + if monitored and not base["agent_ok"] and not base["flask_ok"]: + base["status"] = "连接异常" + elif base["issues"]: + base["status"] = "已监控·需关注" + + day_pnl = float((base.get("trade_stats") or {}).get("total_pnl_u") or 0) + base["issues"].extend( + _collect_open_issues( + monitored=monitored, + agent_ok=base["agent_ok"], + flask_ok=base["flask_ok"], + positions=base["positions"], + hub_mon=hub_mon if isinstance(hub_mon, dict) else None, + day_pnl=day_pnl, + ) + ) + base["issues"] = list(dict.fromkeys(base["issues"])) + return base + + +def _fetch_account_bundle_isolated(ex: dict, trading_day: str, *, for_chat: bool) -> dict[str, Any]: + with httpx.Client() as client: + return _fetch_account_bundle(client, ex, trading_day, for_chat=for_chat) + + +def build_daily_context( + exchanges: list[dict], + *, + trading_day: Optional[str] = None, + for_chat: bool = False, +) -> dict[str, Any]: + day = (trading_day or "").strip()[:10] or current_trading_day( + reset_hour=trading_day_reset_hour() + ) + ex_list = exchanges or [] + if for_chat and len(ex_list) > 1: + workers = min(4, len(ex_list)) + with ThreadPoolExecutor(max_workers=workers) as pool: + accounts = list( + pool.map( + lambda ex: _fetch_account_bundle_isolated(ex, day, for_chat=True), + ex_list, + ) + ) + else: + with httpx.Client() as client: + accounts = [ + _fetch_account_bundle(client, ex, day, for_chat=for_chat) for ex in ex_list + ] + + total_closed_pnl = 0.0 + total_closed = total_win = total_loss = 0 + total_float = 0.0 + total_funding = 0.0 + total_trading = 0.0 + total_open_positions = 0 + funding_known = trading_known = 0 + for ac in accounts: + if ac.get("status") == "未监控": + continue + st = ac.get("trade_stats") or {} + total_closed_pnl += float(st.get("total_pnl_u") or 0) + total_closed += int(st.get("closed_count") or 0) + total_win += int(st.get("win_count") or 0) + total_loss += int(st.get("loss_count") or 0) + total_float += float(ac.get("float_pnl_u") or 0) + total_open_positions += int(ac.get("open_position_count") or _account_open_position_count(ac)) + fu = _safe_float(ac.get("funding_usdt")) + tu = _safe_float(ac.get("trading_usdt")) + if fu is not None: + total_funding += fu + funding_known += 1 + if tu is not None: + total_trading += tu + trading_known += 1 + if not funding_known: + total_funding = None + if not trading_known: + total_trading = None + + totals = { + "trading_day": day, + "prev_trading_day": previous_trading_day(day), + "total_pnl_u": round(total_closed_pnl, 4), + "closed_count": total_closed, + "win_count": total_win, + "loss_count": total_loss, + "float_pnl_u": round(total_float, 4), + "open_position_count": total_open_positions, + "total_funding_usdt": round(total_funding, 4) if total_funding is not None else None, + "total_trading_usdt": round(total_trading, 4) if total_trading is not None else None, + } + if for_chat: + fund_history: list = [] + fund_history_text = "" + else: + snap_accounts = [ + { + **ac, + "monitored": ac.get("status") != "未监控", + } + for ac in accounts + ] + record_fund_snapshot(day, snap_accounts, keep_days=FUND_HISTORY_DAYS) + fund_history = get_fund_history(anchor_day=day, keep_days=FUND_HISTORY_DAYS) + account_names = {str(ac.get("key") or ac.get("id")): ac.get("name") for ac in accounts} + fund_history_text = format_fund_history_text(fund_history, account_names=account_names) + payload = { + "trading_day": day, + "prev_trading_day": previous_trading_day(day), + "totals": totals, + "accounts": accounts, + "fund_history": fund_history, + "fund_history_text": fund_history_text, + } + if for_chat: + text = format_chat_context_for_chat(payload) + else: + text = format_context_text(payload) + digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:16] + return { + "trading_day": day, + "prev_trading_day": previous_trading_day(day), + "totals": totals, + "accounts": accounts, + "fund_history": fund_history, + "fund_history_text": fund_history_text, + "text": text, + "context_hash": digest, + } + + +def build_chat_context( + exchanges: list[dict], + *, + trading_day: Optional[str] = None, + force_refresh: bool = False, +) -> dict[str, Any]: + """聊天专用上下文:并行拉取、跳过资金曲线/昨日成交,短 TTL 缓存。""" + day = (trading_day or "").strip()[:10] or current_trading_day( + reset_hour=trading_day_reset_hour() + ) + ttl = _chat_context_cache_ttl_sec() + now = time.monotonic() + if not force_refresh and ttl > 0: + with _CHAT_CONTEXT_CACHE_LOCK: + hit = _CHAT_CONTEXT_CACHE.get(day) + if hit and (now - float(hit.get("ts") or 0)) < ttl: + return hit["ctx"] + ctx = build_daily_context(exchanges, trading_day=day, for_chat=True) + if ttl > 0: + with _CHAT_CONTEXT_CACHE_LOCK: + _CHAT_CONTEXT_CACHE[day] = {"ts": now, "ctx": ctx} + return ctx + + +def format_context_text(payload: dict) -> str: + lines = [] + totals = payload.get("totals") or {} + day = totals.get("trading_day") + prev_day = totals.get("prev_trading_day") or previous_trading_day(str(day or "")) + lines.append( + f"【合计·今日 {day}】平仓盈亏 {totals.get('total_pnl_u')}U | " + f"笔数 {totals.get('closed_count')}(胜{totals.get('win_count')}/负{totals.get('loss_count')})| " + f"实盘持仓 {totals.get('open_position_count', 0)} 仓 | " + f"浮盈亏 {totals.get('float_pnl_u')}U | " + f"资金账户合计 {_fmt_fund(totals.get('total_funding_usdt'))} | " + f"交易账户合计 {_fmt_fund(totals.get('total_trading_usdt'))}" + ) + lines.append( + f"【对比交易日】昨日={prev_day},今日={day}。" + "「持仓」= 交易所 Agent 实盘;「趋势/关键位/监控单/加仓」= 本地计划,不等于已开仓。" + ) + fund_txt = str(payload.get("fund_history_text") or "").strip() + if fund_txt: + lines.append("") + lines.append(fund_txt) + lines.append("") + for ac in payload.get("accounts") or []: + st = ac.get("trade_stats") or {} + sty = ac.get("trade_stats_yesterday") or {} + lines.append(f"--- 账户:{ac.get('name')} ({ac.get('key')}) ---") + lines.append(f"状态:{ac.get('status')}") + if ac.get("status") == "未监控": + lines.append("") + continue + lines.append( + f"资金账户 {_fmt_fund(ac.get('funding_usdt'))} | " + f"交易账户 {_fmt_fund(ac.get('trading_usdt'))} | " + f"可用 {_fmt_fund(ac.get('available_trading_usdt'))}" + ) + lines.append( + f"今日({day})平仓:{st.get('closed_count')} 笔,盈亏 {st.get('total_pnl_u')}U " + f"(胜{st.get('win_count')}/负{st.get('loss_count')})" + ) + lines.append( + f"昨日({prev_day})平仓:{sty.get('closed_count')} 笔,盈亏 {sty.get('total_pnl_u')}U " + f"(胜{sty.get('win_count')}/负{sty.get('loss_count')})" + ) + open_n = int(ac.get("open_position_count") or _account_open_position_count(ac)) + if open_n <= 0: + lines.append("当前交易所持仓:无(空仓)") + else: + lines.append( + f"当前交易所持仓:{open_n} 仓 | 浮盈亏合计 {ac.get('float_pnl_u')}U" + ) + mon = ac.get("monitor_lines") or {} + if mon.get("trends"): + lines.append("趋势回调计划(本地,非持仓):") + for row in mon["trends"][:8]: + lines.append(f" - {row}") + if mon.get("rolls"): + lines.append("顺势加仓(本地,非持仓):") + for row in mon["rolls"][:8]: + lines.append(f" - {row}") + if mon.get("keys"): + lines.append("关键位监控(本地,非持仓):") + for row in mon["keys"][:8]: + lines.append(f" - {row}") + if mon.get("orders"): + lines.append("进行中的下单监控(本地,非持仓):") + for row in mon["orders"][:8]: + lines.append(f" - {row}") + positions = ac.get("positions") or [] + hub_mon = ac.get("hub_monitor") + if positions: + lines.append("持仓明细(交易所实盘,含止盈止损若已挂):") + for p in positions[:8]: + if not isinstance(p, dict): + continue + lines.append(_format_position_detail_line(p, hub_mon)) + lines.append( + f"Agent合约余额:{ac.get('balance_usdt') if ac.get('balance_usdt') is not None else '未知'} USDT" + ) + trades_today = ac.get("trades") or [] + if trades_today: + lines.append(f"今日平仓明细:") + for t in trades_today[:15]: + lines.append(f" - {_format_trade_line(t)}") + trades_y = ac.get("trades_yesterday") or [] + if trades_y: + lines.append(f"昨日平仓明细:") + for t in trades_y[:15]: + lines.append(f" - {_format_trade_line(t)}") + if not trades_today and not trades_y: + lines.append("平仓明细:无") + issues = ac.get("issues") or [] + if issues: + lines.append("关注点:" + ";".join(issues)) + lines.append("") + return "\n".join(lines).strip() + + +def format_summary_context_text(payload: dict) -> str: + """今日总结专用:仅当日平仓/持仓/监控,不含昨日明细与资金走势。""" + lines = [] + totals = payload.get("totals") or {} + day = totals.get("trading_day") + lines.append( + f"【合计·今日 {day}】平仓盈亏 {totals.get('total_pnl_u')}U | " + f"笔数 {totals.get('closed_count')}(胜{totals.get('win_count')}/负{totals.get('loss_count')})| " + f"实盘持仓 {totals.get('open_position_count', 0)} 仓 | " + f"浮盈亏 {totals.get('float_pnl_u')}U | " + f"资金账户合计 {_fmt_fund(totals.get('total_funding_usdt'))} | " + f"交易账户合计 {_fmt_fund(totals.get('total_trading_usdt'))}" + ) + lines.append( + f"【说明】交易日={day}。" + "「持仓」= 交易所 Agent 实盘;「趋势/关键位/监控单/加仓」= 本地计划,不等于已开仓。" + ) + lines.append("") + for ac in payload.get("accounts") or []: + st = ac.get("trade_stats") or {} + lines.append(f"--- 账户:{ac.get('name')} ({ac.get('key')}) ---") + lines.append(f"状态:{ac.get('status')}") + if ac.get("status") == "未监控": + lines.append("") + continue + lines.append( + f"资金账户 {_fmt_fund(ac.get('funding_usdt'))} | " + f"交易账户 {_fmt_fund(ac.get('trading_usdt'))} | " + f"可用 {_fmt_fund(ac.get('available_trading_usdt'))}" + ) + lines.append( + f"今日({day})平仓:{st.get('closed_count')} 笔,盈亏 {st.get('total_pnl_u')}U " + f"(胜{st.get('win_count')}/负{st.get('loss_count')})" + ) + open_n = int(ac.get("open_position_count") or _account_open_position_count(ac)) + if open_n <= 0: + lines.append("当前交易所持仓:无(空仓)") + else: + lines.append( + f"当前交易所持仓:{open_n} 仓 | 浮盈亏合计 {ac.get('float_pnl_u')}U" + ) + mon = ac.get("monitor_lines") or {} + if mon.get("trends"): + lines.append("趋势回调计划(本地,非持仓):") + for row in mon["trends"][:8]: + lines.append(f" - {row}") + if mon.get("rolls"): + lines.append("顺势加仓(本地,非持仓):") + for row in mon["rolls"][:8]: + lines.append(f" - {row}") + if mon.get("keys"): + lines.append("关键位监控(本地,非持仓):") + for row in mon["keys"][:8]: + lines.append(f" - {row}") + if mon.get("orders"): + lines.append("进行中的下单监控(本地,非持仓):") + for row in mon["orders"][:8]: + lines.append(f" - {row}") + positions = ac.get("positions") or [] + hub_mon = ac.get("hub_monitor") + if positions: + lines.append("持仓明细(交易所实盘,含止盈止损若已挂):") + for p in positions[:8]: + if not isinstance(p, dict): + continue + lines.append(_format_position_detail_line(p, hub_mon)) + lines.append( + f"Agent合约余额:{ac.get('balance_usdt') if ac.get('balance_usdt') is not None else '未知'} USDT" + ) + trades_today = ac.get("trades") or [] + if trades_today: + lines.append("今日平仓明细:") + for t in trades_today[:15]: + lines.append(f" - {_format_trade_line(t)}") + else: + lines.append("今日平仓明细:无") + issues = ac.get("issues") or [] + if issues: + lines.append("关注点:" + ";".join(issues)) + lines.append("") + return "\n".join(lines).strip() + + +def summary_context_hash(payload: dict) -> str: + text = format_summary_context_text(payload) + return hashlib.sha256(text.encode("utf-8")).hexdigest()[:16] + + +def format_account_remark(ac: dict) -> str: + """分户表格备注:监控摘要 + 持仓。""" + parts: list[str] = [] + mon = ac.get("monitor_lines") or {} + if mon.get("trends"): + parts.append(f"趋势{len(mon['trends'])}") + if mon.get("rolls"): + parts.append(f"加仓{len(mon['rolls'])}") + if mon.get("keys"): + parts.append(f"关键位{len(mon['keys'])}") + if mon.get("orders"): + parts.append(f"监控单{len(mon['orders'])}") + positions = ac.get("positions") or [] + if positions: + for p in positions[:2]: + if not isinstance(p, dict): + continue + sym = p.get("symbol") or "?" + side = p.get("side") or "?" + upnl = _position_float_pnl(p) + parts.append(f"{sym} {side} 浮{upnl:.2f}U") + if len(positions) > 2: + parts.append(f"+{len(positions) - 2}仓") + if not parts: + issues = ac.get("issues") or [] + if issues: + return ";".join(str(x) for x in issues[:2]) + return "无" + return ";".join(parts) + + +def format_dashboard_account_detail(ac: dict) -> dict[str, Any]: + """数据看板分户卡片:监控仅数量,持仓逐行(含浮盈亏)。""" + mon = ac.get("monitor_lines") or {} + position_lines: list[dict[str, Any]] = [] + for p in _filter_open_positions(ac.get("positions") or []): + sym = p.get("symbol") or "?" + side = p.get("side") or "?" + upnl = _position_float_pnl(p) + position_lines.append( + { + "kind": "position", + "text": f"{sym} {side}", + "pnl": round(upnl, 4), + } + ) + issues = [str(x) for x in (ac.get("issues") or [])[:3]] + return { + "monitor_counts": { + "keys": len(mon.get("keys") or []), + "orders": len(mon.get("orders") or []), + "trends": len(mon.get("trends") or []), + "rolls": len(mon.get("rolls") or []), + }, + "position_lines": position_lines, + "issues": issues, + } + + +def collect_closed_trades_snapshot( + accounts: list[dict], + *, + today: str, + yesterday: str | None = None, +) -> list[dict]: + rows: list[dict] = [] + for ac in accounts or []: + name = ac.get("name") or ac.get("key") + if yesterday: + for t in ac.get("trades_yesterday") or []: + if not isinstance(t, dict): + continue + rows.append({**t, "account_name": name, "trading_day": yesterday}) + for t in ac.get("trades") or []: + if not isinstance(t, dict): + continue + rows.append({**t, "account_name": name, "trading_day": today}) + rows.sort(key=lambda x: str(x.get("closed_at") or x.get("opened_at") or ""), reverse=True) + return rows[:80] + + +def format_chat_position_overview(payload: dict) -> str: + totals = payload.get("totals") or {} + total_open = int(totals.get("open_position_count") or 0) + if total_open <= 0: + head = f"【实盘持仓总览】当前空仓(监控户合计 0 仓)。浮盈亏 0U 表示无持仓,不是「有仓但不动」。" + else: + head = ( + f"【实盘持仓总览】监控户合计 {total_open} 仓," + f"浮盈亏合计 {totals.get('float_pnl_u')}U。" + ) + lines = [ + head, + "【区分】只有带「持仓明细/交易所实盘」字样的才是已开仓;趋势回调、关键位、下单监控、顺势加仓是本地计划/监控,不算持仓。持仓明细若含止损/止盈价,表示已挂条件单或监控计划中有价位。", + ] + for ac in payload.get("accounts") or []: + if ac.get("status") == "未监控": + continue + n = int(ac.get("open_position_count") or _account_open_position_count(ac)) + mc = _monitor_counts(ac) + mon_parts = [] + if mc["trends"]: + mon_parts.append(f"趋势{mc['trends']}") + if mc["rolls"]: + mon_parts.append(f"加仓{mc['rolls']}") + if mc["keys"]: + mon_parts.append(f"关键位{mc['keys']}") + if mc["orders"]: + mon_parts.append(f"监控单{mc['orders']}") + mon_txt = f";本地监控 {' '.join(mon_parts)}" if mon_parts else "" + if n <= 0: + lines.append(f"- {ac.get('name')}:空仓{mon_txt}") + else: + lines.append( + f"- {ac.get('name')}:{n}仓 浮盈亏{ac.get('float_pnl_u')}U{mon_txt}" + ) + return "\n".join(lines) + + +def format_chat_context_slim(payload: dict) -> str: + """聊天专用:不含 180 日资金曲线与昨日平仓明细,避免挤占对话上下文。""" + totals = payload.get("totals") or {} + day = totals.get("trading_day") + lines = [ + f"【今日合计 {day}】平仓盈亏 {totals.get('total_pnl_u')}U | " + f"笔数 {totals.get('closed_count')}(胜{totals.get('win_count')}/负{totals.get('loss_count')})| " + f"实盘持仓 {totals.get('open_position_count', 0)} 仓 | 浮盈亏 {totals.get('float_pnl_u')}U", + "【说明】持仓=交易所实盘;趋势/关键位/监控单=本地计划,不等于已开仓。持仓行内「止损/止盈」= 交易所条件单或监控计划价(与监控页一致)。", + ] + for ac in payload.get("accounts") or []: + if ac.get("status") == "未监控": + lines.append(f"- {ac.get('name')}:未监控") + continue + st = ac.get("trade_stats") or {} + open_n = int(ac.get("open_position_count") or _account_open_position_count(ac)) + pos_txt = "空仓" if open_n <= 0 else f"{open_n}仓 浮盈亏{ac.get('float_pnl_u')}U" + mc = _monitor_counts(ac) + mon = [] + if mc["trends"]: + mon.append(f"趋势{mc['trends']}") + if mc["rolls"]: + mon.append(f"加仓{mc['rolls']}") + if mc["keys"]: + mon.append(f"关键位{mc['keys']}") + if mc["orders"]: + mon.append(f"监控单{mc['orders']}") + mon_txt = f";监控 {'/'.join(mon)}" if mon else "" + lines.append( + f"- {ac.get('name')}:{pos_txt} | 今日盈亏{st.get('total_pnl_u')}U " + f"({st.get('closed_count')}笔) | 资金{_fmt_fund(ac.get('funding_usdt'))} " + f"交易{_fmt_fund(ac.get('trading_usdt'))}{mon_txt}" + ) + trades = ac.get("trades") or [] + if trades: + for t in trades[:4]: + lines.append(f" · {_format_trade_line(t)}") + if len(trades) > 4: + lines.append(f" · …共{len(trades)}笔今日平仓") + positions = ac.get("positions") or [] + hub_mon = ac.get("hub_monitor") + for p in positions[:4]: + if not isinstance(p, dict): + continue + lines.append(f" · {_format_position_detail_line(p, hub_mon).lstrip(' - ')}") + return "\n".join(lines) + + +def format_chat_context_for_chat( + payload: dict, + max_chars: int = CHAT_CONTEXT_MAX_CHARS, +) -> str: + overview = format_chat_position_overview(payload) + body = format_chat_context_slim(payload) + text = overview + "\n\n" + body + if len(text) <= max_chars: + return text + budget = max(2000, max_chars - len(overview) - 4) + return overview + "\n\n" + body[:budget].rstrip() + "…" + + +def format_chat_context_brief( + payload: dict, + max_chars: int = CHAT_CONTEXT_MAX_CHARS, +) -> str: + return format_chat_context_for_chat(payload, max_chars=max_chars) diff --git a/manual_trading_hub/hub_ai/fund_history.py b/manual_trading_hub/hub_ai/fund_history.py index c40a883..8eda595 100644 --- a/manual_trading_hub/hub_ai/fund_history.py +++ b/manual_trading_hub/hub_ai/fund_history.py @@ -1,18 +1,18 @@ -"""中控 AI:分户资金快照(委托 hub_fund_history_lib,保留 180 交易日)。""" -from __future__ import annotations - -from typing import Any, Optional - -from hub_fund_history_lib import ( - FUND_HISTORY_DAYS, - format_fund_history_text, - get_fund_history, - record_fund_snapshot, -) - -__all__ = [ - "FUND_HISTORY_DAYS", - "format_fund_history_text", - "get_fund_history", - "record_fund_snapshot", -] +"""中控 AI:分户资金快照(委托 hub_fund_history_lib,保留 180 交易日)。""" +from __future__ import annotations + +from typing import Any, Optional + +from lib.hub.hub_fund_history_lib import ( + FUND_HISTORY_DAYS, + format_fund_history_text, + get_fund_history, + record_fund_snapshot, +) + +__all__ = [ + "FUND_HISTORY_DAYS", + "format_fund_history_text", + "get_fund_history", + "record_fund_snapshot", +] diff --git a/manual_trading_hub/hub_ai/routes.py b/manual_trading_hub/hub_ai/routes.py index 1367a90..3b60f2c 100644 --- a/manual_trading_hub/hub_ai/routes.py +++ b/manual_trading_hub/hub_ai/routes.py @@ -1,200 +1,200 @@ -"""中控 AI FastAPI 路由。""" -from __future__ import annotations - -import asyncio -from typing import Callable - -from fastapi import APIRouter, Body, File, Form, HTTPException, UploadFile -from pydantic import BaseModel, Field - -from hub_ai.archive_quote import send_archive_quote_review -from hub_ai.chat import ( - get_chat_state, - remove_chat_session, - send_chat_message, - start_new_chat, - switch_chat_session, -) -from hub_ai.client import model_label -from hub_ai.config import trading_day_reset_hour -from hub_ai.context import build_daily_context -from hub_ai.store import get_latest_summary, list_summaries -from hub_ai.supervisor import send_supervisor_chat -from hub_ai.supervisor_store import get_supervisor_session_state -from hub_ai.summary import generate_daily_summary -from hub_trades_lib import current_trading_day -from settings_store import normalize_supervisor_settings - - -class ChatSendBody(BaseModel): - message: str = "" - trading_day: str = "" - - -class SummaryGenerateBody(BaseModel): - trading_day: str = "" - force: bool = False - - -class ChatNewBody(BaseModel): - trading_day: str = "" - bot_mode: str = "trading" - - -class ChatSwitchBody(BaseModel): - session_id: str = Field(..., min_length=1) - - -class ArchiveQuoteChatBody(BaseModel): - quote_date: str = "" - content: str = "" - - -class SupervisorChatBody(BaseModel): - message: str = "" - trading_day: str = "" - - -def create_hub_ai_router(*, load_all_exchanges: Callable[[], list]) -> APIRouter: - router = APIRouter(prefix="/api/ai", tags=["hub-ai"]) - - def _day(raw: str = "") -> str: - d = (raw or "").strip()[:10] - return d or current_trading_day(reset_hour=trading_day_reset_hour()) - - @router.get("/meta") - def api_ai_meta(): - return { - "ok": True, - "model": model_label(), - "trading_day_reset_hour": trading_day_reset_hour(), - "trading_day": current_trading_day(reset_hour=trading_day_reset_hour()), - "storage": { - "summaries": "hub_ai_summaries.json", - "chat": "hub_ai_chat.json", - }, - } - - @router.get("/context") - def api_ai_context(trading_day: str = ""): - exchanges = load_all_exchanges() - ctx = build_daily_context(exchanges, trading_day=_day(trading_day)) - return {"ok": True, **ctx} - - @router.get("/summary") - def api_ai_summary_list(trading_day: str = ""): - day = _day(trading_day) if trading_day.strip() else "" - items = list_summaries(trading_day=day or None, limit=20) - latest = get_latest_summary(_day(trading_day)) if trading_day.strip() else ( - items[0] if items else None - ) - return { - "ok": True, - "trading_day": _day(trading_day) if trading_day.strip() else None, - "summaries": items, - "latest": latest, - "model": model_label(), - } - - @router.post("/summary/generate") - def api_ai_summary_generate(body: SummaryGenerateBody = SummaryGenerateBody()): - exchanges = load_all_exchanges() - result = generate_daily_summary( - exchanges, - trading_day=_day(body.trading_day) if body.trading_day.strip() else None, - force=bool(body.force), - ) - if not result.get("ok"): - raise HTTPException(status_code=502, detail=result.get("msg") or "生成失败") - result.pop("context", None) - return result - - @router.get("/chat/session") - def api_ai_chat_session(): - state = get_chat_state() - return {"ok": True, **state, "model": model_label()} - - @router.post("/chat/new") - def api_ai_chat_new(body: ChatNewBody = ChatNewBody()): - day = _day(body.trading_day) - return start_new_chat(trading_day=day, bot_mode=body.bot_mode or "trading") - - @router.post("/chat/switch") - def api_ai_chat_switch(body: ChatSwitchBody): - try: - return switch_chat_session(body.session_id.strip()) - except KeyError: - raise HTTPException(status_code=404, detail="会话不存在") - - @router.delete("/chat/session/{session_id}") - def api_ai_chat_delete(session_id: str): - result = remove_chat_session(session_id.strip()) - if not result.get("ok"): - raise HTTPException(status_code=404, detail="会话不存在") - return result - - @router.post("/chat/archive-quote") - def api_ai_chat_archive_quote(body: ArchiveQuoteChatBody = Body(...)): - result = send_archive_quote_review( - quote_date=body.quote_date, - content=body.content, - ) - if not result.get("ok"): - raise HTTPException(status_code=502, detail=result.get("msg") or "发送失败") - return result - - @router.post("/chat/send") - async def api_ai_chat_send( - message: str = Form(""), - trading_day: str = Form(""), - files: list[UploadFile] = File(default=[]), - ): - exchanges = load_all_exchanges() - raw_attachments = [] - for f in files or []: - if not f or not f.filename: - continue - data = await f.read() - raw_attachments.append( - { - "filename": f.filename, - "content_type": f.content_type or "", - "data": data, - } - ) - result = await asyncio.to_thread( - send_chat_message, - exchanges, - message, - trading_day=_day(trading_day) if trading_day.strip() else None, - raw_attachments=raw_attachments, - ) - if not result.get("ok"): - raise HTTPException(status_code=502, detail=result.get("msg") or "发送失败") - return result - - @router.get("/supervisor/session") - def api_ai_supervisor_session(trading_day: str = ""): - day = _day(trading_day) - return get_supervisor_session_state(day) - - @router.get("/supervisor/rules") - def api_ai_supervisor_rules(): - from settings_store import load_settings - - cfg = normalize_supervisor_settings(load_settings().get("supervisor")) - return {"ok": True, "supervisor": cfg} - - @router.post("/supervisor/chat/send") - def api_ai_supervisor_chat_send(body: SupervisorChatBody = SupervisorChatBody()): - exchanges = load_all_exchanges() - result = send_supervisor_chat( - exchanges, - body.message, - trading_day=_day(body.trading_day) if body.trading_day.strip() else None, - ) - if not result.get("ok"): - raise HTTPException(status_code=502, detail=result.get("msg") or "发送失败") - return result - - return router +"""中控 AI FastAPI 路由。""" +from __future__ import annotations + +import asyncio +from typing import Callable + +from fastapi import APIRouter, Body, File, Form, HTTPException, UploadFile +from pydantic import BaseModel, Field + +from hub_ai.archive_quote import send_archive_quote_review +from hub_ai.chat import ( + get_chat_state, + remove_chat_session, + send_chat_message, + start_new_chat, + switch_chat_session, +) +from hub_ai.client import model_label +from hub_ai.config import trading_day_reset_hour +from hub_ai.context import build_daily_context +from hub_ai.store import get_latest_summary, list_summaries +from hub_ai.supervisor import send_supervisor_chat +from hub_ai.supervisor_store import get_supervisor_session_state +from hub_ai.summary import generate_daily_summary +from lib.hub.hub_trades_lib import current_trading_day +from settings_store import normalize_supervisor_settings + + +class ChatSendBody(BaseModel): + message: str = "" + trading_day: str = "" + + +class SummaryGenerateBody(BaseModel): + trading_day: str = "" + force: bool = False + + +class ChatNewBody(BaseModel): + trading_day: str = "" + bot_mode: str = "trading" + + +class ChatSwitchBody(BaseModel): + session_id: str = Field(..., min_length=1) + + +class ArchiveQuoteChatBody(BaseModel): + quote_date: str = "" + content: str = "" + + +class SupervisorChatBody(BaseModel): + message: str = "" + trading_day: str = "" + + +def create_hub_ai_router(*, load_all_exchanges: Callable[[], list]) -> APIRouter: + router = APIRouter(prefix="/api/ai", tags=["hub-ai"]) + + def _day(raw: str = "") -> str: + d = (raw or "").strip()[:10] + return d or current_trading_day(reset_hour=trading_day_reset_hour()) + + @router.get("/meta") + def api_ai_meta(): + return { + "ok": True, + "model": model_label(), + "trading_day_reset_hour": trading_day_reset_hour(), + "trading_day": current_trading_day(reset_hour=trading_day_reset_hour()), + "storage": { + "summaries": "hub_ai_summaries.json", + "chat": "hub_ai_chat.json", + }, + } + + @router.get("/context") + def api_ai_context(trading_day: str = ""): + exchanges = load_all_exchanges() + ctx = build_daily_context(exchanges, trading_day=_day(trading_day)) + return {"ok": True, **ctx} + + @router.get("/summary") + def api_ai_summary_list(trading_day: str = ""): + day = _day(trading_day) if trading_day.strip() else "" + items = list_summaries(trading_day=day or None, limit=20) + latest = get_latest_summary(_day(trading_day)) if trading_day.strip() else ( + items[0] if items else None + ) + return { + "ok": True, + "trading_day": _day(trading_day) if trading_day.strip() else None, + "summaries": items, + "latest": latest, + "model": model_label(), + } + + @router.post("/summary/generate") + def api_ai_summary_generate(body: SummaryGenerateBody = SummaryGenerateBody()): + exchanges = load_all_exchanges() + result = generate_daily_summary( + exchanges, + trading_day=_day(body.trading_day) if body.trading_day.strip() else None, + force=bool(body.force), + ) + if not result.get("ok"): + raise HTTPException(status_code=502, detail=result.get("msg") or "生成失败") + result.pop("context", None) + return result + + @router.get("/chat/session") + def api_ai_chat_session(): + state = get_chat_state() + return {"ok": True, **state, "model": model_label()} + + @router.post("/chat/new") + def api_ai_chat_new(body: ChatNewBody = ChatNewBody()): + day = _day(body.trading_day) + return start_new_chat(trading_day=day, bot_mode=body.bot_mode or "trading") + + @router.post("/chat/switch") + def api_ai_chat_switch(body: ChatSwitchBody): + try: + return switch_chat_session(body.session_id.strip()) + except KeyError: + raise HTTPException(status_code=404, detail="会话不存在") + + @router.delete("/chat/session/{session_id}") + def api_ai_chat_delete(session_id: str): + result = remove_chat_session(session_id.strip()) + if not result.get("ok"): + raise HTTPException(status_code=404, detail="会话不存在") + return result + + @router.post("/chat/archive-quote") + def api_ai_chat_archive_quote(body: ArchiveQuoteChatBody = Body(...)): + result = send_archive_quote_review( + quote_date=body.quote_date, + content=body.content, + ) + if not result.get("ok"): + raise HTTPException(status_code=502, detail=result.get("msg") or "发送失败") + return result + + @router.post("/chat/send") + async def api_ai_chat_send( + message: str = Form(""), + trading_day: str = Form(""), + files: list[UploadFile] = File(default=[]), + ): + exchanges = load_all_exchanges() + raw_attachments = [] + for f in files or []: + if not f or not f.filename: + continue + data = await f.read() + raw_attachments.append( + { + "filename": f.filename, + "content_type": f.content_type or "", + "data": data, + } + ) + result = await asyncio.to_thread( + send_chat_message, + exchanges, + message, + trading_day=_day(trading_day) if trading_day.strip() else None, + raw_attachments=raw_attachments, + ) + if not result.get("ok"): + raise HTTPException(status_code=502, detail=result.get("msg") or "发送失败") + return result + + @router.get("/supervisor/session") + def api_ai_supervisor_session(trading_day: str = ""): + day = _day(trading_day) + return get_supervisor_session_state(day) + + @router.get("/supervisor/rules") + def api_ai_supervisor_rules(): + from settings_store import load_settings + + cfg = normalize_supervisor_settings(load_settings().get("supervisor")) + return {"ok": True, "supervisor": cfg} + + @router.post("/supervisor/chat/send") + def api_ai_supervisor_chat_send(body: SupervisorChatBody = SupervisorChatBody()): + exchanges = load_all_exchanges() + result = send_supervisor_chat( + exchanges, + body.message, + trading_day=_day(body.trading_day) if body.trading_day.strip() else None, + ) + if not result.get("ok"): + raise HTTPException(status_code=502, detail=result.get("msg") or "发送失败") + return result + + return router diff --git a/manual_trading_hub/hub_ai/supervisor.py b/manual_trading_hub/hub_ai/supervisor.py index af42331..fbefb48 100644 --- a/manual_trading_hub/hub_ai/supervisor.py +++ b/manual_trading_hub/hub_ai/supervisor.py @@ -1,125 +1,125 @@ -"""交易监管:AI 评语与用户回聊。""" -from __future__ import annotations - -import sys -from pathlib import Path -from typing import Any, Optional - -_REPO_ROOT = Path(__file__).resolve().parents[2] -if str(_REPO_ROOT) not in sys.path: - sys.path.insert(0, str(_REPO_ROOT)) - -from ai_client import ai_generate # noqa: E402 - -from hub_ai.client import generate_text, model_label -from hub_ai.config import ( - CHAT_MAX_OUTPUT_TOKENS, - CHAT_TEMPERATURE, - trading_day_reset_hour, -) -from hub_ai.context import build_chat_context, format_chat_context_for_chat, format_chat_position_overview -from hub_ai.prompts import SUPERVISOR_SYSTEM, build_supervisor_ai_prompt, build_supervisor_chat_prompt -from hub_ai.supervisor_store import ( - append_supervisor_ai_message, - ensure_supervisor_session, - get_supervisor_session_state, -) -from hub_ai.store import append_chat_message -from hub_ai.text_util import is_ai_error_reply -from hub_supervisor_lib import build_supervisor_fallback_reply -from hub_trades_lib import current_trading_day - -SUPERVISOR_AI_MAX_TOKENS = 320 - - -def generate_supervisor_ai_reply( - *, - event: dict, - warnings: list[dict], - trading_day: str, - session_id: str, - exchanges: list[dict], -) -> str: - ctx = build_chat_context(exchanges, trading_day=trading_day) - brief = format_chat_position_overview(ctx) + "\n" + format_chat_context_for_chat( - ctx, max_chars=2400 - ) - user_prompt = build_supervisor_ai_prompt( - context_text=brief, - trading_day=trading_day, - event=event, - warnings=warnings, - ) - prompt = f"{SUPERVISOR_SYSTEM.strip()}\n\n---\n\n{user_prompt.strip()}" - text = ai_generate(prompt, temperature=0.35, max_tokens=SUPERVISOR_AI_MAX_TOKENS) - text = str(text or "").strip() - if not text or is_ai_error_reply(text): - return build_supervisor_fallback_reply(event, warnings) - return text - - -def make_supervisor_ai_reply_fn(exchanges: list[dict]): - def _fn(*, event: dict, warnings: list[dict], trading_day: str, session_id: str) -> str: - return generate_supervisor_ai_reply( - event=event, - warnings=warnings or [], - trading_day=trading_day, - session_id=session_id, - exchanges=exchanges, - ) - - return _fn - - -def send_supervisor_chat( - exchanges: list[dict], - message: str, - *, - trading_day: str | None = None, -) -> dict[str, Any]: - text = (message or "").strip() - if not text: - return {"ok": False, "msg": "消息不能为空"} - day = (trading_day or "").strip()[:10] or current_trading_day( - reset_hour=trading_day_reset_hour() - ) - session = ensure_supervisor_session(day) - sid = str(session.get("id") or "") - prior = session.get("messages") or [] - ctx = build_chat_context(exchanges, trading_day=day) - brief = format_chat_context_for_chat(ctx, max_chars=6000) - recent = [] - for m in prior[-8:]: - role = m.get("role") - if role not in ("user", "assistant", "system"): - continue - label = {"user": "用户", "assistant": "监管", "system": "系统"}.get(role, role) - recent.append(f"{label}:{str(m.get('content') or '').strip()}") - user_prompt = build_supervisor_chat_prompt( - context_text=brief, - trading_day=day, - history_lines="\n".join(recent), - user_message=text, - ) - reply = generate_text( - system=SUPERVISOR_SYSTEM, - user=user_prompt, - temperature=min(0.4, CHAT_TEMPERATURE), - max_tokens=min(768, CHAT_MAX_OUTPUT_TOKENS), - max_continuations=1, - ) - reply = str(reply or "").strip() - if not reply or is_ai_error_reply(reply): - return {"ok": False, "msg": "AI 暂时不可用,请稍后再试", "session_id": sid} - append_chat_message(sid, "user", text) - session = append_supervisor_ai_message(sid, reply) - state = get_supervisor_session_state(day) - return { - "ok": True, - "trading_day": day, - "session": session, - "reply": reply, - "model": model_label(), - "message_count": state.get("message_count"), - "unread_system": state.get("unread_system"), - } +"""交易监管:AI 评语与用户回聊。""" +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Any, Optional + +_REPO_ROOT = Path(__file__).resolve().parents[2] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from lib.ai.ai_client import ai_generate # noqa: E402 + +from hub_ai.client import generate_text, model_label +from hub_ai.config import ( + CHAT_MAX_OUTPUT_TOKENS, + CHAT_TEMPERATURE, + trading_day_reset_hour, +) +from hub_ai.context import build_chat_context, format_chat_context_for_chat, format_chat_position_overview +from hub_ai.prompts import SUPERVISOR_SYSTEM, build_supervisor_ai_prompt, build_supervisor_chat_prompt +from hub_ai.supervisor_store import ( + append_supervisor_ai_message, + ensure_supervisor_session, + get_supervisor_session_state, +) +from hub_ai.store import append_chat_message +from hub_ai.text_util import is_ai_error_reply +from hub_supervisor_lib import build_supervisor_fallback_reply +from lib.hub.hub_trades_lib import current_trading_day + +SUPERVISOR_AI_MAX_TOKENS = 320 + + +def generate_supervisor_ai_reply( + *, + event: dict, + warnings: list[dict], + trading_day: str, + session_id: str, + exchanges: list[dict], +) -> str: + ctx = build_chat_context(exchanges, trading_day=trading_day) + brief = format_chat_position_overview(ctx) + "\n" + format_chat_context_for_chat( + ctx, max_chars=2400 + ) + user_prompt = build_supervisor_ai_prompt( + context_text=brief, + trading_day=trading_day, + event=event, + warnings=warnings, + ) + prompt = f"{SUPERVISOR_SYSTEM.strip()}\n\n---\n\n{user_prompt.strip()}" + text = ai_generate(prompt, temperature=0.35, max_tokens=SUPERVISOR_AI_MAX_TOKENS) + text = str(text or "").strip() + if not text or is_ai_error_reply(text): + return build_supervisor_fallback_reply(event, warnings) + return text + + +def make_supervisor_ai_reply_fn(exchanges: list[dict]): + def _fn(*, event: dict, warnings: list[dict], trading_day: str, session_id: str) -> str: + return generate_supervisor_ai_reply( + event=event, + warnings=warnings or [], + trading_day=trading_day, + session_id=session_id, + exchanges=exchanges, + ) + + return _fn + + +def send_supervisor_chat( + exchanges: list[dict], + message: str, + *, + trading_day: str | None = None, +) -> dict[str, Any]: + text = (message or "").strip() + if not text: + return {"ok": False, "msg": "消息不能为空"} + day = (trading_day or "").strip()[:10] or current_trading_day( + reset_hour=trading_day_reset_hour() + ) + session = ensure_supervisor_session(day) + sid = str(session.get("id") or "") + prior = session.get("messages") or [] + ctx = build_chat_context(exchanges, trading_day=day) + brief = format_chat_context_for_chat(ctx, max_chars=6000) + recent = [] + for m in prior[-8:]: + role = m.get("role") + if role not in ("user", "assistant", "system"): + continue + label = {"user": "用户", "assistant": "监管", "system": "系统"}.get(role, role) + recent.append(f"{label}:{str(m.get('content') or '').strip()}") + user_prompt = build_supervisor_chat_prompt( + context_text=brief, + trading_day=day, + history_lines="\n".join(recent), + user_message=text, + ) + reply = generate_text( + system=SUPERVISOR_SYSTEM, + user=user_prompt, + temperature=min(0.4, CHAT_TEMPERATURE), + max_tokens=min(768, CHAT_MAX_OUTPUT_TOKENS), + max_continuations=1, + ) + reply = str(reply or "").strip() + if not reply or is_ai_error_reply(reply): + return {"ok": False, "msg": "AI 暂时不可用,请稍后再试", "session_id": sid} + append_chat_message(sid, "user", text) + session = append_supervisor_ai_message(sid, reply) + state = get_supervisor_session_state(day) + return { + "ok": True, + "trading_day": day, + "session": session, + "reply": reply, + "model": model_label(), + "message_count": state.get("message_count"), + "unread_system": state.get("unread_system"), + } diff --git a/manual_trading_hub/hub_dashboard.py b/manual_trading_hub/hub_dashboard.py index d0e752f..0b5c940 100644 --- a/manual_trading_hub/hub_dashboard.py +++ b/manual_trading_hub/hub_dashboard.py @@ -1,107 +1,107 @@ -"""中控数据看板:四户当日总览(无 AI,纯数据聚合)。""" -from __future__ import annotations - -from datetime import datetime, timezone -from typing import Any, Optional - -from hub_ai.context import ( - build_daily_context, - collect_closed_trades_snapshot, - format_account_remark, - format_dashboard_account_detail, -) -from hub_ai.config import trading_day_reset_hour -from hub_trades_lib import current_trading_day - -LOSS_ALERT_PCT = 5.0 -DASHBOARD_POLL_INTERVAL_SEC = 60 - - -def _safe_float(v: Any) -> Optional[float]: - try: - if v is None or v == "": - return None - return float(v) - except (TypeError, ValueError): - return None - - -def _account_capital_base(ac: dict) -> Optional[float]: - funding = _safe_float(ac.get("funding_usdt")) - trading = _safe_float(ac.get("trading_usdt")) - if funding is not None and trading is not None: - return funding + trading - if funding is not None: - return funding - if trading is not None: - return trading - return None - - -def _enrich_account_row(ac: dict) -> dict: - st = ac.get("trade_stats") or {} - capital = _account_capital_base(ac) - day_pnl = float(st.get("total_pnl_u") or 0) - loss_pct: Optional[float] = None - loss_alert = False - if capital is not None and capital > 0 and day_pnl < -1e-9: - loss_pct = round(abs(day_pnl) / capital * 100.0, 2) - loss_alert = loss_pct >= LOSS_ALERT_PCT - return { - "id": ac.get("id"), - "key": ac.get("key"), - "name": ac.get("name"), - "status": ac.get("status"), - "monitored": ac.get("status") != "未监控", - "funding_usdt": ac.get("funding_usdt"), - "trading_usdt": ac.get("trading_usdt"), - "capital_total_usdt": round(capital, 4) if capital is not None else None, - "available_trading_usdt": ac.get("available_trading_usdt"), - "pnl_u": st.get("total_pnl_u"), - "closed_count": st.get("closed_count"), - "win_count": st.get("win_count"), - "loss_count": st.get("loss_count"), - "float_pnl_u": ac.get("float_pnl_u"), - "open_position_count": ac.get("open_position_count"), - "remark": format_account_remark(ac), - **format_dashboard_account_detail(ac), - "issues": ac.get("issues") or [], - "daily_loss_pct": loss_pct, - "loss_alert": loss_alert, - } - - -def build_dashboard_payload( - exchanges: list[dict], - *, - trading_day: str | None = None, -) -> dict[str, Any]: - ctx = build_daily_context(exchanges, trading_day=trading_day) - day = ctx["trading_day"] - accounts_raw = ctx.get("accounts") or [] - accounts = [ - _enrich_account_row(ac) - for ac in accounts_raw - if ac.get("status") != "未监控" - ] - closed_trades = collect_closed_trades_snapshot( - [ac for ac in accounts_raw if ac.get("status") != "未监控"], - today=day, - ) - loss_alert_count = sum(1 for ac in accounts if ac.get("loss_alert")) - now = datetime.now(timezone.utc).astimezone().strftime("%Y-%m-%d %H:%M:%S") - return { - "ok": True, - "updated_at": now, - "trading_day": day, - "totals": ctx.get("totals"), - "accounts": accounts, - "closed_trades": closed_trades, - "loss_alert_pct_threshold": LOSS_ALERT_PCT, - "loss_alert_count": loss_alert_count, - "poll_interval_sec": DASHBOARD_POLL_INTERVAL_SEC, - } - - -def default_trading_day() -> str: - return current_trading_day(reset_hour=trading_day_reset_hour()) +"""中控数据看板:四户当日总览(无 AI,纯数据聚合)。""" +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any, Optional + +from hub_ai.context import ( + build_daily_context, + collect_closed_trades_snapshot, + format_account_remark, + format_dashboard_account_detail, +) +from hub_ai.config import trading_day_reset_hour +from lib.hub.hub_trades_lib import current_trading_day + +LOSS_ALERT_PCT = 5.0 +DASHBOARD_POLL_INTERVAL_SEC = 60 + + +def _safe_float(v: Any) -> Optional[float]: + try: + if v is None or v == "": + return None + return float(v) + except (TypeError, ValueError): + return None + + +def _account_capital_base(ac: dict) -> Optional[float]: + funding = _safe_float(ac.get("funding_usdt")) + trading = _safe_float(ac.get("trading_usdt")) + if funding is not None and trading is not None: + return funding + trading + if funding is not None: + return funding + if trading is not None: + return trading + return None + + +def _enrich_account_row(ac: dict) -> dict: + st = ac.get("trade_stats") or {} + capital = _account_capital_base(ac) + day_pnl = float(st.get("total_pnl_u") or 0) + loss_pct: Optional[float] = None + loss_alert = False + if capital is not None and capital > 0 and day_pnl < -1e-9: + loss_pct = round(abs(day_pnl) / capital * 100.0, 2) + loss_alert = loss_pct >= LOSS_ALERT_PCT + return { + "id": ac.get("id"), + "key": ac.get("key"), + "name": ac.get("name"), + "status": ac.get("status"), + "monitored": ac.get("status") != "未监控", + "funding_usdt": ac.get("funding_usdt"), + "trading_usdt": ac.get("trading_usdt"), + "capital_total_usdt": round(capital, 4) if capital is not None else None, + "available_trading_usdt": ac.get("available_trading_usdt"), + "pnl_u": st.get("total_pnl_u"), + "closed_count": st.get("closed_count"), + "win_count": st.get("win_count"), + "loss_count": st.get("loss_count"), + "float_pnl_u": ac.get("float_pnl_u"), + "open_position_count": ac.get("open_position_count"), + "remark": format_account_remark(ac), + **format_dashboard_account_detail(ac), + "issues": ac.get("issues") or [], + "daily_loss_pct": loss_pct, + "loss_alert": loss_alert, + } + + +def build_dashboard_payload( + exchanges: list[dict], + *, + trading_day: str | None = None, +) -> dict[str, Any]: + ctx = build_daily_context(exchanges, trading_day=trading_day) + day = ctx["trading_day"] + accounts_raw = ctx.get("accounts") or [] + accounts = [ + _enrich_account_row(ac) + for ac in accounts_raw + if ac.get("status") != "未监控" + ] + closed_trades = collect_closed_trades_snapshot( + [ac for ac in accounts_raw if ac.get("status") != "未监控"], + today=day, + ) + loss_alert_count = sum(1 for ac in accounts if ac.get("loss_alert")) + now = datetime.now(timezone.utc).astimezone().strftime("%Y-%m-%d %H:%M:%S") + return { + "ok": True, + "updated_at": now, + "trading_day": day, + "totals": ctx.get("totals"), + "accounts": accounts, + "closed_trades": closed_trades, + "loss_alert_pct_threshold": LOSS_ALERT_PCT, + "loss_alert_count": loss_alert_count, + "poll_interval_sec": DASHBOARD_POLL_INTERVAL_SEC, + } + + +def default_trading_day() -> str: + return current_trading_day(reset_hour=trading_day_reset_hour()) diff --git a/manual_trading_hub/hub_supervisor_lib.py b/manual_trading_hub/hub_supervisor_lib.py index 660c6f7..01dcf69 100644 --- a/manual_trading_hub/hub_supervisor_lib.py +++ b/manual_trading_hub/hub_supervisor_lib.py @@ -1,757 +1,757 @@ -"""交易监管:事件分类、频率规则、会话消息与企业微信推送。""" -from __future__ import annotations - -import json -import os -import threading -import uuid -from datetime import datetime, timedelta -from pathlib import Path -from typing import Any, Callable, Optional - -from hub_trades_lib import current_trading_day, parse_dt_for_trading_day - -HUB_DIR = Path(__file__).resolve().parent -STATE_PATH = HUB_DIR / "hub_supervisor_state.json" - -PROGRAM_RESULTS = frozenset({"止盈", "止损", "保本止盈", "移动止盈"}) -MANUAL_CLOSE_RESULTS = frozenset({"手动平仓"}) -HUB_CLOSE_RESULTS = frozenset({"强制清仓"}) -WEAK_RESULTS = frozenset({"外部平仓", "时间平仓"}) - -EVENT_OPEN = "open" -EVENT_MANUAL_CLOSE = "manual_close" -EVENT_HUB_CLOSE = "hub_close" -EVENT_PROGRAM_TP = "program_tp" -EVENT_PROGRAM_SL = "program_sl" -EVENT_EXTERNAL = "external" -EVENT_FREQ_WARN = "freq_warn" - -DEFAULT_SUPERVISOR = { - "enabled": True, - "wechat_webhook": "", - "wechat_link_base": "http://127.0.0.1:5100/ai?mode=supervisor", - "wechat_prefix": "【交易监管】", - "wechat_on_program_tp_sl": True, - "manual_close_daily_warn": 2, - "interval_warn_minutes": 15, - "freq_30m_count": 2, - "reopen_after_close_minutes": 30, -} - - -def _now_str() -> str: - return datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - -def _atomic_write(path: Path, data: dict) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - tmp = path.with_suffix(path.suffix + ".tmp") - tmp.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") - os.replace(tmp, path) - - -def _load_json(path: Path, default: dict) -> dict: - if not path.is_file(): - return dict(default) - try: - loaded = json.loads(path.read_text(encoding="utf-8")) - if isinstance(loaded, dict): - return loaded - except Exception: - pass - return dict(default) - - -def normalize_supervisor_settings(raw: dict | None) -> dict: - out = dict(DEFAULT_SUPERVISOR) - env_webhook = (os.getenv("SUPERVISOR_WECHAT_WEBHOOK") or "").strip() - env_link = (os.getenv("SUPERVISOR_WECHAT_LINK") or "").strip() - if env_webhook: - out["wechat_webhook"] = env_webhook - if env_link: - out["wechat_link_base"] = env_link - if not isinstance(raw, dict): - return out - for key in DEFAULT_SUPERVISOR: - if key not in raw: - continue - val = raw.get(key) - if key == "enabled" or key == "wechat_on_program_tp_sl": - out[key] = bool(val) - elif key in ("manual_close_daily_warn", "freq_30m_count"): - try: - out[key] = max(1, int(val)) - except (TypeError, ValueError): - pass - elif key in ("interval_warn_minutes", "reopen_after_close_minutes"): - try: - out[key] = max(1, int(val)) - except (TypeError, ValueError): - pass - elif isinstance(val, str): - out[key] = val.strip() - return out - - -def load_supervisor_state() -> dict: - data = _load_json(STATE_PATH, {"version": 1, "trading_day": "", "processed": [], "positions": {}, "stats": {}}) - data.setdefault("version", 1) - data.setdefault("processed", []) - data.setdefault("positions", {}) - data.setdefault("stats", {}) - return data - - -def save_supervisor_state(data: dict) -> None: - processed = list(data.get("processed") or []) - if len(processed) > 500: - processed = processed[-500:] - data["processed"] = processed - _atomic_write(STATE_PATH, data) - - -def _trade_event_id(trade: dict) -> str: - return "|".join( - [ - str(trade.get("account_name") or trade.get("account_key") or ""), - str(trade.get("symbol") or ""), - str(trade.get("closed_at") or ""), - str(trade.get("result") or ""), - str(trade.get("pnl_amount") or ""), - ] - ) - - -def classify_close_result(result: str) -> str: - r = (result or "").strip() - if r in PROGRAM_RESULTS: - if r == "止损": - return EVENT_PROGRAM_SL - return EVENT_PROGRAM_TP - if r in MANUAL_CLOSE_RESULTS: - return EVENT_MANUAL_CLOSE - if r in HUB_CLOSE_RESULTS: - return EVENT_HUB_CLOSE - if r in WEAK_RESULTS: - return EVENT_EXTERNAL - return EVENT_EXTERNAL - - -def is_supervised_event(event_type: str) -> bool: - return event_type in (EVENT_OPEN, EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE) - - -def is_program_event(event_type: str) -> bool: - return event_type in (EVENT_PROGRAM_TP, EVENT_PROGRAM_SL) - - -def _normalize_position_symbol(sym: str) -> str: - """统一合约名,避免 ZEC/USDT 与 ZEC/USDT:USDT 被当成两笔持仓。""" - s = (sym or "").strip().upper() - if not s: - return "" - if s.endswith(":USDT") and "/" in s: - return s.rsplit(":", 1)[0] - return s - - -def _position_key(exchange_id: str, symbol: str, side: str) -> str: - sym = _normalize_position_symbol(symbol) - sd = (side or "long").strip().lower() or "long" - return f"{exchange_id}|{sym}|{sd}" - - -def _position_contracts(pos: dict) -> float: - for key in ("contracts", "contracts_signed", "size"): - try: - v = pos.get(key) - if v is not None and v != "": - return abs(float(v)) - except (TypeError, ValueError): - continue - return 0.0 - - -def collect_position_keys(board_payload: dict | None) -> dict[str, dict]: - out: dict[str, dict] = {} - rows = (board_payload or {}).get("rows") or [] - for row in rows: - if not isinstance(row, dict): - continue - ex_id = str(row.get("id") or row.get("key") or "") - ex_name = str(row.get("name") or row.get("key") or ex_id) - ag = row.get("agent") or {} - for p in ag.get("positions") or []: - if not isinstance(p, dict): - continue - if _position_contracts(p) < 1e-12: - continue - sym = str(p.get("symbol") or "") - side = str(p.get("side") or "").lower() or "long" - key = _position_key(ex_id, sym, side) - out[key] = { - "exchange_id": ex_id, - "exchange_name": ex_name, - "symbol": sym, - "side": side, - "contracts": _position_contracts(p), - } - return out - - -def _board_agent_snapshot_ready(board_payload: dict | None) -> bool: - """监控板各启用账户 agent 快照已就绪(避免空板先入库导致后续持仓误判为新开)。""" - if not isinstance(board_payload, dict) or board_payload.get("ok") is False: - return False - rows = board_payload.get("rows") or [] - if not rows: - return False - seen = 0 - for row in rows: - if not isinstance(row, dict): - continue - if row.get("enabled") is False: - continue - ag = row.get("agent") - if not isinstance(ag, dict): - return False - seen += 1 - return seen > 0 - - -def _entry_contracts(entry: dict | None) -> float: - if not isinstance(entry, dict): - return 0.0 - try: - return float(entry.get("contracts") or 0) - except (TypeError, ValueError): - return 0.0 - - -def detect_new_opens( - prev_positions: dict[str, dict], - curr_positions: dict[str, dict], -) -> list[dict]: - """仅当某合约从空仓变为有仓时视为新开(已有持仓不加仓不算)。""" - events = [] - for key, info in curr_positions.items(): - curr_c = _entry_contracts(info) - if curr_c < 1e-12: - continue - prev_c = _entry_contracts(prev_positions.get(key)) - if prev_c >= 1e-12: - continue - events.append({"event_type": EVENT_OPEN, "event_id": f"open:{key}:{_now_str()[:16]}", **info}) - return events - - -def detect_new_closes( - prev_processed: set[str], - closed_trades: list[dict], -) -> list[dict]: - events = [] - for trade in closed_trades or []: - if not isinstance(trade, dict): - continue - eid = _trade_event_id(trade) - if eid in prev_processed: - continue - event_type = classify_close_result(str(trade.get("result") or "")) - events.append( - { - "event_type": event_type, - "event_id": f"close:{eid}", - "account_name": trade.get("account_name"), - "symbol": trade.get("symbol"), - "direction": trade.get("direction"), - "result": trade.get("result"), - "pnl_amount": trade.get("pnl_amount"), - "closed_at": trade.get("closed_at"), - } - ) - return events - - -def _parse_event_dt(raw: Any) -> Optional[datetime]: - return parse_dt_for_trading_day(raw) - - -def _supervised_close_times(stats: dict, trading_day: str) -> list[datetime]: - rows = (stats.get(trading_day) or {}).get("supervised_closes") or [] - out = [] - for item in rows: - if isinstance(item, dict): - dt = _parse_event_dt(item.get("closed_at") or item.get("at")) - else: - dt = _parse_event_dt(item) - if dt: - out.append(dt) - out.sort() - return out - - -def _record_supervised_event(stats: dict, trading_day: str, event: dict) -> None: - day_stats = stats.setdefault(trading_day, {}) - et = str(event.get("event_type") or "") - if et == EVENT_OPEN: - opens = list(day_stats.get("supervised_opens") or []) - opens.append({"at": _now_str(), "symbol": event.get("symbol")}) - day_stats["supervised_opens"] = opens[-50:] - return - if et not in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE): - return - closes = list(day_stats.get("supervised_closes") or []) - closes.append( - { - "at": _now_str(), - "closed_at": event.get("closed_at"), - "event_type": et, - "pnl_amount": event.get("pnl_amount"), - } - ) - day_stats["supervised_closes"] = closes[-50:] - - -def evaluate_frequency_warnings( - *, - trading_day: str, - event: dict, - stats: dict, - settings: dict, -) -> list[dict]: - if not is_supervised_event(str(event.get("event_type") or "")): - return [] - warnings: list[dict] = [] - day_stats = stats.setdefault(trading_day, {}) - closes = _supervised_close_times(stats, trading_day) - now = datetime.now() - if event.get("event_type") in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE): - evt_dt = _parse_event_dt(event.get("closed_at")) or now - closes = closes + [evt_dt] - closes.sort() - open_count = len(day_stats.get("supervised_opens") or []) - close_count = len(day_stats.get("supervised_closes") or []) - if event.get("event_type") == EVENT_OPEN: - open_count += 1 - elif event.get("event_type") in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE): - close_count += 1 - - interval_min = int(settings.get("interval_warn_minutes") or 15) - daily_warn = int(settings.get("manual_close_daily_warn") or 2) - freq_30m = int(settings.get("freq_30m_count") or 2) - reopen_min = int(settings.get("reopen_after_close_minutes") or 30) - - if event.get("event_type") in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE) and len(closes) >= 2: - prev = closes[-2] - cur = closes[-1] - gap = (cur - prev).total_seconds() / 60.0 - if gap < interval_min: - warnings.append( - { - "rule": "INTERVAL_SHORT", - "message": f"两笔手动/中控平间隔仅 {int(gap)} 分钟(阈值 {interval_min} 分钟)", - } - ) - - recent_closes = [t for t in closes if (now - t).total_seconds() <= 30 * 60] - if event.get("event_type") in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE) and len(recent_closes) >= freq_30m: - warnings.append( - { - "rule": "FREQ_30M", - "message": f"30 分钟内手动/中控平已达 {len(recent_closes)} 笔(阈值 {freq_30m} 笔)", - } - ) - - supervised_total = open_count + close_count - if supervised_total >= daily_warn and event.get("event_type") in ( - EVENT_MANUAL_CLOSE, - EVENT_HUB_CLOSE, - EVENT_OPEN, - ): - if close_count >= daily_warn: - warnings.append( - { - "rule": "DAILY_COUNT", - "message": f"今日手动/中控平 {close_count} 笔(阈值 {daily_warn} 笔),注意过度交易", - } - ) - - if event.get("event_type") == EVENT_OPEN and closes: - last_close = closes[-1] - gap_open = (now - last_close).total_seconds() / 60.0 - if gap_open < reopen_min: - warnings.append( - { - "rule": "REOPEN_FAST", - "message": f"距上一笔手动/中控平仅 {int(gap_open)} 分钟又新开仓(阈值 {reopen_min} 分钟)", - } - ) - - loss_streak = 0 - for item in reversed((stats.get(trading_day) or {}).get("supervised_closes") or []): - try: - pnl = float((item or {}).get("pnl_amount") or 0) - except (TypeError, ValueError): - pnl = 0.0 - if pnl < 0: - loss_streak += 1 - else: - break - if event.get("event_type") in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE): - try: - pnl = float(event.get("pnl_amount") or 0) - except (TypeError, ValueError): - pnl = 0.0 - if pnl < 0: - loss_streak += 1 - else: - loss_streak = 0 - if loss_streak >= 2 and event.get("event_type") in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE): - warnings.append( - { - "rule": "LOSS_STREAK", - "message": f"连续 {loss_streak} 笔手动/中控亏损,先停一停", - } - ) - - deduped = [] - seen = set() - for w in warnings: - key = w.get("rule") - if key in seen: - continue - seen.add(key) - deduped.append(w) - return deduped - - -def event_tag(event_type: str) -> str: - return { - EVENT_OPEN: "监管·开仓", - EVENT_MANUAL_CLOSE: "监管·手动平", - EVENT_HUB_CLOSE: "监管·中控平", - EVENT_PROGRAM_TP: "监管·程序止盈", - EVENT_PROGRAM_SL: "监管·程序止损", - EVENT_EXTERNAL: "监管·外部平", - EVENT_FREQ_WARN: "监管·频率", - }.get(event_type, "监管") - - -def _fmt_pnl_u(pnl: Any) -> str: - try: - v = float(pnl) - sign = "+" if v > 0 else "" - return f"{sign}{v:.4f}".rstrip("0").rstrip(".") + "U" - except (TypeError, ValueError): - return "" - - -def build_supervisor_fallback_reply(event: dict, warnings: list[dict] | None = None) -> str: - """AI 不可用或返回空时的短评语(不展示错误文案)。""" - et = str(event.get("event_type") or "") - sym = str(event.get("symbol") or "—") - ex = str(event.get("exchange_name") or event.get("account_name") or "").strip() - pnl_txt = _fmt_pnl_u(event.get("pnl_amount")) - warn = (warnings or [])[:1] - warn_txt = str(warn[0].get("message") or "").strip() if warn else "" - - if et == EVENT_PROGRAM_SL: - base = f"{sym} 程序止损" - if pnl_txt: - base += f"({pnl_txt})" - base += ",按计划出场是纪律。先歇一会儿,别急着马上再开。" - elif et == EVENT_PROGRAM_TP: - base = f"{sym} 程序止盈" - if pnl_txt: - base += f"({pnl_txt})" - base += ",执行不错。保持节奏,别立刻反手再开一单。" - elif et == EVENT_OPEN: - who = f"{ex} " if ex else "" - base = f"看到 {who}新开 {sym}。动手前确认是不是计划内,别因为上一笔情绪再开。" - elif et == EVENT_HUB_CLOSE: - base = f"中控平了 {sym}" - if pnl_txt: - base += f"({pnl_txt})" - base += "。" - base += f" {warn_txt}" if warn_txt else " 停一停,别连着手痒。" - elif et == EVENT_MANUAL_CLOSE: - base = f"手动平了 {sym}" - if pnl_txt: - base += f"({pnl_txt})" - base += "。" - base += f" {warn_txt}" if warn_txt else " 想好再开下一单。" - elif et == EVENT_FREQ_WARN: - base = warn_txt or "今日操作偏频繁,先休息一会儿。" - else: - base = "收到。确认是否按计划执行,别连续加码。" - return base.strip()[:320] - - -def build_system_message(event: dict, *, trading_day: str, warnings: list[dict] | None = None) -> str: - tag = event_tag(str(event.get("event_type") or "")) - ex = event.get("exchange_name") or event.get("account_name") or "—" - sym = event.get("symbol") or "—" - lines = [f"[{tag}] {ex} · {sym}"] - et = event.get("event_type") - if et == EVENT_OPEN: - side = event.get("side") or event.get("direction") or "" - if side: - lines.append(f"方向:{side}") - elif et in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE, EVENT_PROGRAM_TP, EVENT_PROGRAM_SL, EVENT_EXTERNAL): - res = event.get("result") or "" - pnl = event.get("pnl_amount") - if pnl is not None: - lines.append(f"结果 {res} · 盈亏 {pnl}U") - else: - lines.append(f"结果 {res}") - if event.get("closed_at"): - lines.append(f"平仓时间 {event.get('closed_at')}") - for w in warnings or []: - lines.append(f"⚠ {w.get('message')}") - lines.append(f"交易日 {trading_day}") - return "\n".join(lines) - - -def build_wechat_body( - event: dict, - *, - trading_day: str, - link_base: str, - system_text: str, -) -> str: - link = (link_base or "").strip() - if link: - sep = "&" if "?" in link else "?" - link = f"{link}{sep}day={trading_day}" - body = system_text.replace("\n", "\n") - if link: - body += f"\n详情:{link}" - return body - - -def should_send_wechat(event: dict, settings: dict) -> bool: - if not settings.get("enabled", True): - return False - webhook = (settings.get("wechat_webhook") or "").strip() - if not webhook or "replace-me" in webhook.lower(): - return False - et = str(event.get("event_type") or "") - if is_program_event(et): - return bool(settings.get("wechat_on_program_tp_sl", True)) - if et == EVENT_EXTERNAL: - return False - return True - - -def send_supervisor_wechat( - event: dict, - *, - trading_day: str, - settings: dict, - system_text: str, -) -> bool: - if not should_send_wechat(event, settings): - return False - from wechat_notify_lib import send_wechat_webhook - - prefix = (settings.get("wechat_prefix") or "【交易监管】").strip() - body = build_wechat_body( - event, - trading_day=trading_day, - link_base=str(settings.get("wechat_link_base") or ""), - system_text=system_text, - ) - return bool( - send_wechat_webhook( - str(settings.get("wechat_webhook") or ""), - body, - prefix=prefix, - ) - ) - - -_notify_hook: Optional[Callable[[], None]] = None - - -def set_supervisor_notify_hook(fn: Optional[Callable[[], None]]) -> None: - global _notify_hook - _notify_hook = fn - - -def _fire_notify() -> None: - if _notify_hook: - try: - _notify_hook() - except Exception: - pass - - -def process_supervisor_tick( - dashboard_payload: dict | None, - board_payload: dict | None, - settings_root: dict | None, - *, - reset_hour: int = 8, - ai_reply_fn: Optional[Callable[..., str]] = None, -) -> dict[str, Any]: - """单次监管扫描:对比快照、写会话、推微信、可选 AI 评语。""" - from hub_ai.supervisor_store import ( - append_supervisor_ai_message, - append_supervisor_system_message, - ensure_supervisor_session, - ) - - sup_cfg = normalize_supervisor_settings((settings_root or {}).get("supervisor")) - if not sup_cfg.get("enabled", True): - return {"ok": True, "skipped": True, "reason": "disabled"} - - dash = dashboard_payload or {} - trading_day = str(dash.get("trading_day") or current_trading_day(reset_hour=reset_hour)) - state = load_supervisor_state() - if str(state.get("trading_day") or "") != trading_day: - state = { - "version": 1, - "trading_day": trading_day, - "processed": [], - "positions": {}, - "stats": {trading_day: state.get("stats", {}).get(trading_day, {})}, - "positions_baseline_ready": False, - } - - processed = set(str(x) for x in (state.get("processed") or [])) - stats = dict(state.get("stats") or {}) - prev_positions = dict(state.get("positions") or {}) - curr_positions = collect_position_keys(board_payload) - closed_trades = dash.get("closed_trades") or [] - board_ready = _board_agent_snapshot_ready(board_payload) - - if not state.get("positions_baseline_ready"): - for trade in closed_trades: - if isinstance(trade, dict): - processed.add(f"close:{_trade_event_id(trade)}") - if not board_ready: - state["trading_day"] = trading_day - state["processed"] = list(processed) - save_supervisor_state(state) - return {"ok": True, "events": 0, "waiting_board": True, "trading_day": trading_day} - state["trading_day"] = trading_day - state["processed"] = list(processed) - state["positions"] = curr_positions - state["positions_baseline_ready"] = True - state["initialized"] = True - save_supervisor_state(state) - return { - "ok": True, - "events": 0, - "seeded": True, - "trading_day": trading_day, - "positions": len(curr_positions), - } - - raw_events = detect_new_opens(prev_positions, curr_positions) + detect_new_closes( - processed, closed_trades - ) - if not raw_events: - state["positions"] = curr_positions - save_supervisor_state(state) - return {"ok": True, "events": 0} - - session = ensure_supervisor_session(trading_day) - session_id = str(session.get("id") or "") - handled = 0 - - for event in raw_events: - eid = str(event.get("event_id") or uuid.uuid4().hex) - if eid in processed: - continue - et = str(event.get("event_type") or "") - if et == EVENT_EXTERNAL: - processed.add(eid) - continue - - warnings = evaluate_frequency_warnings( - trading_day=trading_day, - event=event, - stats=stats, - settings=sup_cfg, - ) - if is_supervised_event(et): - _record_supervised_event(stats, trading_day, event) - - system_text = build_system_message(event, trading_day=trading_day, warnings=warnings) - append_supervisor_system_message( - session_id, - system_text, - event_type=et, - level="warn" if warnings else "info", - ) - send_supervisor_wechat( - event, - trading_day=trading_day, - settings=sup_cfg, - system_text=system_text, - ) - for w in warnings: - warn_event = { - "event_type": EVENT_FREQ_WARN, - "event_id": f"warn:{eid}:{w.get('rule')}", - **event, - "warn_message": w.get("message"), - } - warn_text = f"[{event_tag(EVENT_FREQ_WARN)}] {w.get('message')}" - append_supervisor_system_message( - session_id, - warn_text, - event_type=EVENT_FREQ_WARN, - level="warn", - ) - send_supervisor_wechat( - warn_event, - trading_day=trading_day, - settings=sup_cfg, - system_text=warn_text, - ) - - if ai_reply_fn and et != EVENT_EXTERNAL: - evt_snapshot = dict(event) - evt_warnings = list(warnings) - - def _ai_bg() -> None: - try: - reply = ai_reply_fn( - event=evt_snapshot, - warnings=evt_warnings, - trading_day=trading_day, - session_id=session_id, - ) - from hub_ai.text_util import is_ai_error_reply - - text = str(reply or "").strip() - if not text or is_ai_error_reply(text): - text = build_supervisor_fallback_reply(evt_snapshot, evt_warnings) - if text: - append_supervisor_ai_message(session_id, text) - _fire_notify() - except Exception: - try: - fb = build_supervisor_fallback_reply(evt_snapshot, evt_warnings) - if fb: - append_supervisor_ai_message(session_id, fb) - _fire_notify() - except Exception: - pass - - threading.Thread(target=_ai_bg, daemon=True).start() - - processed.add(eid) - handled += 1 - - state["trading_day"] = trading_day - state["processed"] = list(processed) - state["positions"] = curr_positions - state["stats"] = stats - save_supervisor_state(state) - if handled: - _fire_notify() - return {"ok": True, "events": handled, "trading_day": trading_day, "session_id": session_id} +"""交易监管:事件分类、频率规则、会话消息与企业微信推送。""" +from __future__ import annotations + +import json +import os +import threading +import uuid +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Callable, Optional + +from lib.hub.hub_trades_lib import current_trading_day, parse_dt_for_trading_day + +HUB_DIR = Path(__file__).resolve().parent +STATE_PATH = HUB_DIR / "hub_supervisor_state.json" + +PROGRAM_RESULTS = frozenset({"止盈", "止损", "保本止盈", "移动止盈"}) +MANUAL_CLOSE_RESULTS = frozenset({"手动平仓"}) +HUB_CLOSE_RESULTS = frozenset({"强制清仓"}) +WEAK_RESULTS = frozenset({"外部平仓", "时间平仓"}) + +EVENT_OPEN = "open" +EVENT_MANUAL_CLOSE = "manual_close" +EVENT_HUB_CLOSE = "hub_close" +EVENT_PROGRAM_TP = "program_tp" +EVENT_PROGRAM_SL = "program_sl" +EVENT_EXTERNAL = "external" +EVENT_FREQ_WARN = "freq_warn" + +DEFAULT_SUPERVISOR = { + "enabled": True, + "wechat_webhook": "", + "wechat_link_base": "http://127.0.0.1:5100/ai?mode=supervisor", + "wechat_prefix": "【交易监管】", + "wechat_on_program_tp_sl": True, + "manual_close_daily_warn": 2, + "interval_warn_minutes": 15, + "freq_30m_count": 2, + "reopen_after_close_minutes": 30, +} + + +def _now_str() -> str: + return datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def _atomic_write(path: Path, data: dict) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + tmp.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, path) + + +def _load_json(path: Path, default: dict) -> dict: + if not path.is_file(): + return dict(default) + try: + loaded = json.loads(path.read_text(encoding="utf-8")) + if isinstance(loaded, dict): + return loaded + except Exception: + pass + return dict(default) + + +def normalize_supervisor_settings(raw: dict | None) -> dict: + out = dict(DEFAULT_SUPERVISOR) + env_webhook = (os.getenv("SUPERVISOR_WECHAT_WEBHOOK") or "").strip() + env_link = (os.getenv("SUPERVISOR_WECHAT_LINK") or "").strip() + if env_webhook: + out["wechat_webhook"] = env_webhook + if env_link: + out["wechat_link_base"] = env_link + if not isinstance(raw, dict): + return out + for key in DEFAULT_SUPERVISOR: + if key not in raw: + continue + val = raw.get(key) + if key == "enabled" or key == "wechat_on_program_tp_sl": + out[key] = bool(val) + elif key in ("manual_close_daily_warn", "freq_30m_count"): + try: + out[key] = max(1, int(val)) + except (TypeError, ValueError): + pass + elif key in ("interval_warn_minutes", "reopen_after_close_minutes"): + try: + out[key] = max(1, int(val)) + except (TypeError, ValueError): + pass + elif isinstance(val, str): + out[key] = val.strip() + return out + + +def load_supervisor_state() -> dict: + data = _load_json(STATE_PATH, {"version": 1, "trading_day": "", "processed": [], "positions": {}, "stats": {}}) + data.setdefault("version", 1) + data.setdefault("processed", []) + data.setdefault("positions", {}) + data.setdefault("stats", {}) + return data + + +def save_supervisor_state(data: dict) -> None: + processed = list(data.get("processed") or []) + if len(processed) > 500: + processed = processed[-500:] + data["processed"] = processed + _atomic_write(STATE_PATH, data) + + +def _trade_event_id(trade: dict) -> str: + return "|".join( + [ + str(trade.get("account_name") or trade.get("account_key") or ""), + str(trade.get("symbol") or ""), + str(trade.get("closed_at") or ""), + str(trade.get("result") or ""), + str(trade.get("pnl_amount") or ""), + ] + ) + + +def classify_close_result(result: str) -> str: + r = (result or "").strip() + if r in PROGRAM_RESULTS: + if r == "止损": + return EVENT_PROGRAM_SL + return EVENT_PROGRAM_TP + if r in MANUAL_CLOSE_RESULTS: + return EVENT_MANUAL_CLOSE + if r in HUB_CLOSE_RESULTS: + return EVENT_HUB_CLOSE + if r in WEAK_RESULTS: + return EVENT_EXTERNAL + return EVENT_EXTERNAL + + +def is_supervised_event(event_type: str) -> bool: + return event_type in (EVENT_OPEN, EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE) + + +def is_program_event(event_type: str) -> bool: + return event_type in (EVENT_PROGRAM_TP, EVENT_PROGRAM_SL) + + +def _normalize_position_symbol(sym: str) -> str: + """统一合约名,避免 ZEC/USDT 与 ZEC/USDT:USDT 被当成两笔持仓。""" + s = (sym or "").strip().upper() + if not s: + return "" + if s.endswith(":USDT") and "/" in s: + return s.rsplit(":", 1)[0] + return s + + +def _position_key(exchange_id: str, symbol: str, side: str) -> str: + sym = _normalize_position_symbol(symbol) + sd = (side or "long").strip().lower() or "long" + return f"{exchange_id}|{sym}|{sd}" + + +def _position_contracts(pos: dict) -> float: + for key in ("contracts", "contracts_signed", "size"): + try: + v = pos.get(key) + if v is not None and v != "": + return abs(float(v)) + except (TypeError, ValueError): + continue + return 0.0 + + +def collect_position_keys(board_payload: dict | None) -> dict[str, dict]: + out: dict[str, dict] = {} + rows = (board_payload or {}).get("rows") or [] + for row in rows: + if not isinstance(row, dict): + continue + ex_id = str(row.get("id") or row.get("key") or "") + ex_name = str(row.get("name") or row.get("key") or ex_id) + ag = row.get("agent") or {} + for p in ag.get("positions") or []: + if not isinstance(p, dict): + continue + if _position_contracts(p) < 1e-12: + continue + sym = str(p.get("symbol") or "") + side = str(p.get("side") or "").lower() or "long" + key = _position_key(ex_id, sym, side) + out[key] = { + "exchange_id": ex_id, + "exchange_name": ex_name, + "symbol": sym, + "side": side, + "contracts": _position_contracts(p), + } + return out + + +def _board_agent_snapshot_ready(board_payload: dict | None) -> bool: + """监控板各启用账户 agent 快照已就绪(避免空板先入库导致后续持仓误判为新开)。""" + if not isinstance(board_payload, dict) or board_payload.get("ok") is False: + return False + rows = board_payload.get("rows") or [] + if not rows: + return False + seen = 0 + for row in rows: + if not isinstance(row, dict): + continue + if row.get("enabled") is False: + continue + ag = row.get("agent") + if not isinstance(ag, dict): + return False + seen += 1 + return seen > 0 + + +def _entry_contracts(entry: dict | None) -> float: + if not isinstance(entry, dict): + return 0.0 + try: + return float(entry.get("contracts") or 0) + except (TypeError, ValueError): + return 0.0 + + +def detect_new_opens( + prev_positions: dict[str, dict], + curr_positions: dict[str, dict], +) -> list[dict]: + """仅当某合约从空仓变为有仓时视为新开(已有持仓不加仓不算)。""" + events = [] + for key, info in curr_positions.items(): + curr_c = _entry_contracts(info) + if curr_c < 1e-12: + continue + prev_c = _entry_contracts(prev_positions.get(key)) + if prev_c >= 1e-12: + continue + events.append({"event_type": EVENT_OPEN, "event_id": f"open:{key}:{_now_str()[:16]}", **info}) + return events + + +def detect_new_closes( + prev_processed: set[str], + closed_trades: list[dict], +) -> list[dict]: + events = [] + for trade in closed_trades or []: + if not isinstance(trade, dict): + continue + eid = _trade_event_id(trade) + if eid in prev_processed: + continue + event_type = classify_close_result(str(trade.get("result") or "")) + events.append( + { + "event_type": event_type, + "event_id": f"close:{eid}", + "account_name": trade.get("account_name"), + "symbol": trade.get("symbol"), + "direction": trade.get("direction"), + "result": trade.get("result"), + "pnl_amount": trade.get("pnl_amount"), + "closed_at": trade.get("closed_at"), + } + ) + return events + + +def _parse_event_dt(raw: Any) -> Optional[datetime]: + return parse_dt_for_trading_day(raw) + + +def _supervised_close_times(stats: dict, trading_day: str) -> list[datetime]: + rows = (stats.get(trading_day) or {}).get("supervised_closes") or [] + out = [] + for item in rows: + if isinstance(item, dict): + dt = _parse_event_dt(item.get("closed_at") or item.get("at")) + else: + dt = _parse_event_dt(item) + if dt: + out.append(dt) + out.sort() + return out + + +def _record_supervised_event(stats: dict, trading_day: str, event: dict) -> None: + day_stats = stats.setdefault(trading_day, {}) + et = str(event.get("event_type") or "") + if et == EVENT_OPEN: + opens = list(day_stats.get("supervised_opens") or []) + opens.append({"at": _now_str(), "symbol": event.get("symbol")}) + day_stats["supervised_opens"] = opens[-50:] + return + if et not in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE): + return + closes = list(day_stats.get("supervised_closes") or []) + closes.append( + { + "at": _now_str(), + "closed_at": event.get("closed_at"), + "event_type": et, + "pnl_amount": event.get("pnl_amount"), + } + ) + day_stats["supervised_closes"] = closes[-50:] + + +def evaluate_frequency_warnings( + *, + trading_day: str, + event: dict, + stats: dict, + settings: dict, +) -> list[dict]: + if not is_supervised_event(str(event.get("event_type") or "")): + return [] + warnings: list[dict] = [] + day_stats = stats.setdefault(trading_day, {}) + closes = _supervised_close_times(stats, trading_day) + now = datetime.now() + if event.get("event_type") in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE): + evt_dt = _parse_event_dt(event.get("closed_at")) or now + closes = closes + [evt_dt] + closes.sort() + open_count = len(day_stats.get("supervised_opens") or []) + close_count = len(day_stats.get("supervised_closes") or []) + if event.get("event_type") == EVENT_OPEN: + open_count += 1 + elif event.get("event_type") in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE): + close_count += 1 + + interval_min = int(settings.get("interval_warn_minutes") or 15) + daily_warn = int(settings.get("manual_close_daily_warn") or 2) + freq_30m = int(settings.get("freq_30m_count") or 2) + reopen_min = int(settings.get("reopen_after_close_minutes") or 30) + + if event.get("event_type") in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE) and len(closes) >= 2: + prev = closes[-2] + cur = closes[-1] + gap = (cur - prev).total_seconds() / 60.0 + if gap < interval_min: + warnings.append( + { + "rule": "INTERVAL_SHORT", + "message": f"两笔手动/中控平间隔仅 {int(gap)} 分钟(阈值 {interval_min} 分钟)", + } + ) + + recent_closes = [t for t in closes if (now - t).total_seconds() <= 30 * 60] + if event.get("event_type") in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE) and len(recent_closes) >= freq_30m: + warnings.append( + { + "rule": "FREQ_30M", + "message": f"30 分钟内手动/中控平已达 {len(recent_closes)} 笔(阈值 {freq_30m} 笔)", + } + ) + + supervised_total = open_count + close_count + if supervised_total >= daily_warn and event.get("event_type") in ( + EVENT_MANUAL_CLOSE, + EVENT_HUB_CLOSE, + EVENT_OPEN, + ): + if close_count >= daily_warn: + warnings.append( + { + "rule": "DAILY_COUNT", + "message": f"今日手动/中控平 {close_count} 笔(阈值 {daily_warn} 笔),注意过度交易", + } + ) + + if event.get("event_type") == EVENT_OPEN and closes: + last_close = closes[-1] + gap_open = (now - last_close).total_seconds() / 60.0 + if gap_open < reopen_min: + warnings.append( + { + "rule": "REOPEN_FAST", + "message": f"距上一笔手动/中控平仅 {int(gap_open)} 分钟又新开仓(阈值 {reopen_min} 分钟)", + } + ) + + loss_streak = 0 + for item in reversed((stats.get(trading_day) or {}).get("supervised_closes") or []): + try: + pnl = float((item or {}).get("pnl_amount") or 0) + except (TypeError, ValueError): + pnl = 0.0 + if pnl < 0: + loss_streak += 1 + else: + break + if event.get("event_type") in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE): + try: + pnl = float(event.get("pnl_amount") or 0) + except (TypeError, ValueError): + pnl = 0.0 + if pnl < 0: + loss_streak += 1 + else: + loss_streak = 0 + if loss_streak >= 2 and event.get("event_type") in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE): + warnings.append( + { + "rule": "LOSS_STREAK", + "message": f"连续 {loss_streak} 笔手动/中控亏损,先停一停", + } + ) + + deduped = [] + seen = set() + for w in warnings: + key = w.get("rule") + if key in seen: + continue + seen.add(key) + deduped.append(w) + return deduped + + +def event_tag(event_type: str) -> str: + return { + EVENT_OPEN: "监管·开仓", + EVENT_MANUAL_CLOSE: "监管·手动平", + EVENT_HUB_CLOSE: "监管·中控平", + EVENT_PROGRAM_TP: "监管·程序止盈", + EVENT_PROGRAM_SL: "监管·程序止损", + EVENT_EXTERNAL: "监管·外部平", + EVENT_FREQ_WARN: "监管·频率", + }.get(event_type, "监管") + + +def _fmt_pnl_u(pnl: Any) -> str: + try: + v = float(pnl) + sign = "+" if v > 0 else "" + return f"{sign}{v:.4f}".rstrip("0").rstrip(".") + "U" + except (TypeError, ValueError): + return "" + + +def build_supervisor_fallback_reply(event: dict, warnings: list[dict] | None = None) -> str: + """AI 不可用或返回空时的短评语(不展示错误文案)。""" + et = str(event.get("event_type") or "") + sym = str(event.get("symbol") or "—") + ex = str(event.get("exchange_name") or event.get("account_name") or "").strip() + pnl_txt = _fmt_pnl_u(event.get("pnl_amount")) + warn = (warnings or [])[:1] + warn_txt = str(warn[0].get("message") or "").strip() if warn else "" + + if et == EVENT_PROGRAM_SL: + base = f"{sym} 程序止损" + if pnl_txt: + base += f"({pnl_txt})" + base += ",按计划出场是纪律。先歇一会儿,别急着马上再开。" + elif et == EVENT_PROGRAM_TP: + base = f"{sym} 程序止盈" + if pnl_txt: + base += f"({pnl_txt})" + base += ",执行不错。保持节奏,别立刻反手再开一单。" + elif et == EVENT_OPEN: + who = f"{ex} " if ex else "" + base = f"看到 {who}新开 {sym}。动手前确认是不是计划内,别因为上一笔情绪再开。" + elif et == EVENT_HUB_CLOSE: + base = f"中控平了 {sym}" + if pnl_txt: + base += f"({pnl_txt})" + base += "。" + base += f" {warn_txt}" if warn_txt else " 停一停,别连着手痒。" + elif et == EVENT_MANUAL_CLOSE: + base = f"手动平了 {sym}" + if pnl_txt: + base += f"({pnl_txt})" + base += "。" + base += f" {warn_txt}" if warn_txt else " 想好再开下一单。" + elif et == EVENT_FREQ_WARN: + base = warn_txt or "今日操作偏频繁,先休息一会儿。" + else: + base = "收到。确认是否按计划执行,别连续加码。" + return base.strip()[:320] + + +def build_system_message(event: dict, *, trading_day: str, warnings: list[dict] | None = None) -> str: + tag = event_tag(str(event.get("event_type") or "")) + ex = event.get("exchange_name") or event.get("account_name") or "—" + sym = event.get("symbol") or "—" + lines = [f"[{tag}] {ex} · {sym}"] + et = event.get("event_type") + if et == EVENT_OPEN: + side = event.get("side") or event.get("direction") or "" + if side: + lines.append(f"方向:{side}") + elif et in (EVENT_MANUAL_CLOSE, EVENT_HUB_CLOSE, EVENT_PROGRAM_TP, EVENT_PROGRAM_SL, EVENT_EXTERNAL): + res = event.get("result") or "" + pnl = event.get("pnl_amount") + if pnl is not None: + lines.append(f"结果 {res} · 盈亏 {pnl}U") + else: + lines.append(f"结果 {res}") + if event.get("closed_at"): + lines.append(f"平仓时间 {event.get('closed_at')}") + for w in warnings or []: + lines.append(f"⚠ {w.get('message')}") + lines.append(f"交易日 {trading_day}") + return "\n".join(lines) + + +def build_wechat_body( + event: dict, + *, + trading_day: str, + link_base: str, + system_text: str, +) -> str: + link = (link_base or "").strip() + if link: + sep = "&" if "?" in link else "?" + link = f"{link}{sep}day={trading_day}" + body = system_text.replace("\n", "\n") + if link: + body += f"\n详情:{link}" + return body + + +def should_send_wechat(event: dict, settings: dict) -> bool: + if not settings.get("enabled", True): + return False + webhook = (settings.get("wechat_webhook") or "").strip() + if not webhook or "replace-me" in webhook.lower(): + return False + et = str(event.get("event_type") or "") + if is_program_event(et): + return bool(settings.get("wechat_on_program_tp_sl", True)) + if et == EVENT_EXTERNAL: + return False + return True + + +def send_supervisor_wechat( + event: dict, + *, + trading_day: str, + settings: dict, + system_text: str, +) -> bool: + if not should_send_wechat(event, settings): + return False + from lib.common.wechat_notify_lib import send_wechat_webhook + + prefix = (settings.get("wechat_prefix") or "【交易监管】").strip() + body = build_wechat_body( + event, + trading_day=trading_day, + link_base=str(settings.get("wechat_link_base") or ""), + system_text=system_text, + ) + return bool( + send_wechat_webhook( + str(settings.get("wechat_webhook") or ""), + body, + prefix=prefix, + ) + ) + + +_notify_hook: Optional[Callable[[], None]] = None + + +def set_supervisor_notify_hook(fn: Optional[Callable[[], None]]) -> None: + global _notify_hook + _notify_hook = fn + + +def _fire_notify() -> None: + if _notify_hook: + try: + _notify_hook() + except Exception: + pass + + +def process_supervisor_tick( + dashboard_payload: dict | None, + board_payload: dict | None, + settings_root: dict | None, + *, + reset_hour: int = 8, + ai_reply_fn: Optional[Callable[..., str]] = None, +) -> dict[str, Any]: + """单次监管扫描:对比快照、写会话、推微信、可选 AI 评语。""" + from hub_ai.supervisor_store import ( + append_supervisor_ai_message, + append_supervisor_system_message, + ensure_supervisor_session, + ) + + sup_cfg = normalize_supervisor_settings((settings_root or {}).get("supervisor")) + if not sup_cfg.get("enabled", True): + return {"ok": True, "skipped": True, "reason": "disabled"} + + dash = dashboard_payload or {} + trading_day = str(dash.get("trading_day") or current_trading_day(reset_hour=reset_hour)) + state = load_supervisor_state() + if str(state.get("trading_day") or "") != trading_day: + state = { + "version": 1, + "trading_day": trading_day, + "processed": [], + "positions": {}, + "stats": {trading_day: state.get("stats", {}).get(trading_day, {})}, + "positions_baseline_ready": False, + } + + processed = set(str(x) for x in (state.get("processed") or [])) + stats = dict(state.get("stats") or {}) + prev_positions = dict(state.get("positions") or {}) + curr_positions = collect_position_keys(board_payload) + closed_trades = dash.get("closed_trades") or [] + board_ready = _board_agent_snapshot_ready(board_payload) + + if not state.get("positions_baseline_ready"): + for trade in closed_trades: + if isinstance(trade, dict): + processed.add(f"close:{_trade_event_id(trade)}") + if not board_ready: + state["trading_day"] = trading_day + state["processed"] = list(processed) + save_supervisor_state(state) + return {"ok": True, "events": 0, "waiting_board": True, "trading_day": trading_day} + state["trading_day"] = trading_day + state["processed"] = list(processed) + state["positions"] = curr_positions + state["positions_baseline_ready"] = True + state["initialized"] = True + save_supervisor_state(state) + return { + "ok": True, + "events": 0, + "seeded": True, + "trading_day": trading_day, + "positions": len(curr_positions), + } + + raw_events = detect_new_opens(prev_positions, curr_positions) + detect_new_closes( + processed, closed_trades + ) + if not raw_events: + state["positions"] = curr_positions + save_supervisor_state(state) + return {"ok": True, "events": 0} + + session = ensure_supervisor_session(trading_day) + session_id = str(session.get("id") or "") + handled = 0 + + for event in raw_events: + eid = str(event.get("event_id") or uuid.uuid4().hex) + if eid in processed: + continue + et = str(event.get("event_type") or "") + if et == EVENT_EXTERNAL: + processed.add(eid) + continue + + warnings = evaluate_frequency_warnings( + trading_day=trading_day, + event=event, + stats=stats, + settings=sup_cfg, + ) + if is_supervised_event(et): + _record_supervised_event(stats, trading_day, event) + + system_text = build_system_message(event, trading_day=trading_day, warnings=warnings) + append_supervisor_system_message( + session_id, + system_text, + event_type=et, + level="warn" if warnings else "info", + ) + send_supervisor_wechat( + event, + trading_day=trading_day, + settings=sup_cfg, + system_text=system_text, + ) + for w in warnings: + warn_event = { + "event_type": EVENT_FREQ_WARN, + "event_id": f"warn:{eid}:{w.get('rule')}", + **event, + "warn_message": w.get("message"), + } + warn_text = f"[{event_tag(EVENT_FREQ_WARN)}] {w.get('message')}" + append_supervisor_system_message( + session_id, + warn_text, + event_type=EVENT_FREQ_WARN, + level="warn", + ) + send_supervisor_wechat( + warn_event, + trading_day=trading_day, + settings=sup_cfg, + system_text=warn_text, + ) + + if ai_reply_fn and et != EVENT_EXTERNAL: + evt_snapshot = dict(event) + evt_warnings = list(warnings) + + def _ai_bg() -> None: + try: + reply = ai_reply_fn( + event=evt_snapshot, + warnings=evt_warnings, + trading_day=trading_day, + session_id=session_id, + ) + from hub_ai.text_util import is_ai_error_reply + + text = str(reply or "").strip() + if not text or is_ai_error_reply(text): + text = build_supervisor_fallback_reply(evt_snapshot, evt_warnings) + if text: + append_supervisor_ai_message(session_id, text) + _fire_notify() + except Exception: + try: + fb = build_supervisor_fallback_reply(evt_snapshot, evt_warnings) + if fb: + append_supervisor_ai_message(session_id, fb) + _fire_notify() + except Exception: + pass + + threading.Thread(target=_ai_bg, daemon=True).start() + + processed.add(eid) + handled += 1 + + state["trading_day"] = trading_day + state["processed"] = list(processed) + state["positions"] = curr_positions + state["stats"] = stats + save_supervisor_state(state) + if handled: + _fire_notify() + return {"ok": True, "events": handled, "trading_day": trading_day, "session_id": session_id} diff --git a/requirements.txt b/requirements.txt index 18063c8..929087d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ # crypto_monitor 四个 Flask 子项目共用依赖(Binance / Gate / Gate_bot / OKX) # 安装:在各子目录 venv 内执行 pip install -r ../requirements.txt +# 共用 Python 库位于 ../lib/,启动时需将仓库根加入 PYTHONPATH(各 app.py / PM2 已配置) flask>=3.0,<4 requests>=2.31,<3 ccxt>=4.2,<5 diff --git a/scripts/backfill_trend_strategy_snapshots.py b/scripts/backfill_trend_strategy_snapshots.py index f55edcf..56b17b4 100644 --- a/scripts/backfill_trend_strategy_snapshots.py +++ b/scripts/backfill_trend_strategy_snapshots.py @@ -1,248 +1,248 @@ -#!/usr/bin/env python3 -"""补录缺失的趋势回调策略结束快照(strategy_trade_snapshots)。 - -适用:gate_bot 等在计划结束(止盈/止损/手动)时因 strategy_trend_cfg 未注册而漏写快照的历史数据。 -保本移交路径通常已有快照,本脚本默认跳过「已有任意快照」的计划。 - -用法(在仓库根目录,Linux 请用 python3): - python3 scripts/backfill_trend_strategy_snapshots.py \\ - --db crypto_monitor_gate_bot/crypto.db --dry-run - python3 scripts/backfill_trend_strategy_snapshots.py \\ - --db crypto_monitor_gate_bot/crypto.db --apply -""" -from __future__ import annotations - -import argparse -import sqlite3 -import sys -from pathlib import Path - -_REPO_ROOT = Path(__file__).resolve().parents[1] -if str(_REPO_ROOT) not in sys.path: - sys.path.insert(0, str(_REPO_ROOT)) - -from strategy_snapshot_lib import ( # noqa: E402 - STRATEGY_TREND, - init_strategy_snapshot_table, - save_trend_plan_snapshot, -) - -PLAN_STATUS_LABEL = { - "stopped_sl": "止损", - "stopped_tp": "止盈", - "stopped_manual": "手动平仓", - "stopped_handoff": "保本移交", -} - -TRADE_RESULT_LABEL = { - "止损": "止损", - "止盈": "止盈", - "手动平仓": "手动平仓", - "移动止盈": "止盈", - "保本止盈": "止盈", - "强制清仓": "手动平仓", -} - - -def _row_dict(row) -> dict: - if row is None: - return {} - try: - return dict(row) - except Exception: - return {} - - -def infer_exit_price( - direction: str, - entry: float | None, - margin: float | None, - leverage: float | None, - pnl: float | None, -) -> float | None: - """由本地 calc_pnl 口径反推平仓价(供补录快照 exit_price)。""" - try: - trigger = float(entry) - margin_f = float(margin) - lev = float(leverage) - pnl_f = float(pnl) - except (TypeError, ValueError): - return None - if trigger <= 0 or margin_f <= 0 or lev <= 0: - return None - notional = margin_f * lev - if notional <= 0: - return None - ratio = pnl_f / notional - if (direction or "long").strip().lower() == "short": - return round(trigger * (1.0 - ratio), 10) - return round(trigger * (1.0 + ratio), 10) - - -def resolve_result_label(plan: dict, trade: dict | None) -> str: - status = (plan.get("status") or "").strip() - if status in PLAN_STATUS_LABEL: - return PLAN_STATUS_LABEL[status] - if trade: - res = (trade.get("result") or "").strip() - if res in TRADE_RESULT_LABEL: - return TRADE_RESULT_LABEL[res] - if res: - return res - msg = (plan.get("message") or "").strip() - if msg: - return msg[:32] - return "结束" - - -def find_missing_plans( - conn: sqlite3.Connection, - *, - plan_id: int | None = None, - since: str | None = None, -) -> list[dict]: - sql = """ - SELECT p.* - FROM trend_pullback_plans p - WHERE TRIM(COALESCE(p.status, '')) != 'active' - AND NOT EXISTS ( - SELECT 1 FROM strategy_trade_snapshots s - WHERE s.strategy_type = ? AND s.source_id = p.id - ) - """ - params: list[object] = [STRATEGY_TREND] - if plan_id is not None: - sql += " AND p.id = ?" - params.append(int(plan_id)) - if since: - sql += " AND COALESCE(p.opened_at, '') >= ?" - params.append(since.strip()) - sql += " ORDER BY p.id ASC" - rows = conn.execute(sql, params).fetchall() - return [_row_dict(r) for r in rows] - - -def fetch_trade_for_plan(conn: sqlite3.Connection, plan_id: int) -> dict | None: - row = conn.execute( - """ - SELECT * FROM trade_records - WHERE trend_plan_id = ? - ORDER BY COALESCE(closed_at_ms, 0) DESC, id DESC - LIMIT 1 - """, - (int(plan_id),), - ).fetchone() - return _row_dict(row) if row else None - - -def backfill_one(conn: sqlite3.Connection, plan: dict, *, dry_run: bool) -> dict: - plan_id = int(plan["id"]) - trade = fetch_trade_for_plan(conn, plan_id) - result_label = resolve_result_label(plan, trade) - pnl_amount = None - closed_at = None - exit_price = None - entry = plan.get("avg_entry_price") or plan.get("live_price_ref") - margin = plan.get("plan_margin_capital") - leverage = plan.get("leverage") - - if trade: - pnl_amount = trade.get("pnl_amount") - closed_at = trade.get("closed_at") - entry = trade.get("trigger_price") or entry - margin = trade.get("margin_capital") or margin - leverage = trade.get("leverage") or leverage - exit_price = infer_exit_price( - plan.get("direction") or trade.get("direction") or "long", - entry, - margin, - leverage, - pnl_amount, - ) - - info = { - "plan_id": plan_id, - "symbol": plan.get("symbol"), - "status": plan.get("status"), - "result_label": result_label, - "closed_at": closed_at, - "pnl_amount": pnl_amount, - "exit_price": exit_price, - "legs_done": plan.get("legs_done"), - "dca_legs": plan.get("dca_legs"), - "has_trade": bool(trade), - } - - if dry_run: - return info - - save_trend_plan_snapshot( - {}, - conn, - plan, - result_label=result_label, - exit_price=exit_price, - pnl_amount=float(pnl_amount) if pnl_amount is not None else None, - closed_at=closed_at, - ) - return info - - -def main() -> int: - parser = argparse.ArgumentParser( - description="Backfill missing trend_pullback strategy_trade_snapshots rows." - ) - parser.add_argument("--db", required=True, help="Path to instance sqlite db") - parser.add_argument("--plan-id", type=int, help="Only backfill this trend plan id") - parser.add_argument( - "--since", - help="Only plans with opened_at >= YYYY-MM-DD (optional)", - ) - parser.add_argument("--dry-run", action="store_true", help="Preview only (default)") - parser.add_argument("--apply", action="store_true", help="Write snapshots") - args = parser.parse_args() - if not args.dry_run and not args.apply: - args.dry_run = True - - db_path = Path(args.db).expanduser().resolve() - if not db_path.is_file(): - print(f"[ERR] DB not found: {db_path}") - return 1 - - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - init_strategy_snapshot_table(conn) - - missing = find_missing_plans( - conn, plan_id=args.plan_id, since=args.since - ) - if not missing: - print("[INFO] No closed trend plans missing strategy snapshots.") - conn.close() - return 0 - - print(f"[INFO] Found {len(missing)} plan(s) without strategy snapshot.") - applied = 0 - for plan in missing: - info = backfill_one(conn, plan, dry_run=not args.apply) - trade_hint = "有交易记录" if info["has_trade"] else "无交易记录" - print( - f" - plan #{info['plan_id']} {info['symbol']} " - f"status={info['status']} → {info['result_label']} " - f"closed={info['closed_at'] or '—'} pnl={info['pnl_amount']} " - f"补仓 {info['legs_done']}/{info['dca_legs']} ({trade_hint})" - ) - applied += 1 - - if args.apply: - conn.commit() - print(f"[OK] Backfilled {applied} snapshot(s).") - else: - print("[DRY-RUN] No changes written. Re-run with --apply to commit.") - - conn.close() - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) +#!/usr/bin/env python3 +"""补录缺失的趋势回调策略结束快照(strategy_trade_snapshots)。 + +适用:gate_bot 等在计划结束(止盈/止损/手动)时因 strategy_trend_cfg 未注册而漏写快照的历史数据。 +保本移交路径通常已有快照,本脚本默认跳过「已有任意快照」的计划。 + +用法(在仓库根目录,Linux 请用 python3): + python3 scripts/backfill_trend_strategy_snapshots.py \\ + --db crypto_monitor_gate_bot/crypto.db --dry-run + python3 scripts/backfill_trend_strategy_snapshots.py \\ + --db crypto_monitor_gate_bot/crypto.db --apply +""" +from __future__ import annotations + +import argparse +import sqlite3 +import sys +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from lib.strategy.strategy_snapshot_lib import ( # noqa: E402 + STRATEGY_TREND, + init_strategy_snapshot_table, + save_trend_plan_snapshot, +) + +PLAN_STATUS_LABEL = { + "stopped_sl": "止损", + "stopped_tp": "止盈", + "stopped_manual": "手动平仓", + "stopped_handoff": "保本移交", +} + +TRADE_RESULT_LABEL = { + "止损": "止损", + "止盈": "止盈", + "手动平仓": "手动平仓", + "移动止盈": "止盈", + "保本止盈": "止盈", + "强制清仓": "手动平仓", +} + + +def _row_dict(row) -> dict: + if row is None: + return {} + try: + return dict(row) + except Exception: + return {} + + +def infer_exit_price( + direction: str, + entry: float | None, + margin: float | None, + leverage: float | None, + pnl: float | None, +) -> float | None: + """由本地 calc_pnl 口径反推平仓价(供补录快照 exit_price)。""" + try: + trigger = float(entry) + margin_f = float(margin) + lev = float(leverage) + pnl_f = float(pnl) + except (TypeError, ValueError): + return None + if trigger <= 0 or margin_f <= 0 or lev <= 0: + return None + notional = margin_f * lev + if notional <= 0: + return None + ratio = pnl_f / notional + if (direction or "long").strip().lower() == "short": + return round(trigger * (1.0 - ratio), 10) + return round(trigger * (1.0 + ratio), 10) + + +def resolve_result_label(plan: dict, trade: dict | None) -> str: + status = (plan.get("status") or "").strip() + if status in PLAN_STATUS_LABEL: + return PLAN_STATUS_LABEL[status] + if trade: + res = (trade.get("result") or "").strip() + if res in TRADE_RESULT_LABEL: + return TRADE_RESULT_LABEL[res] + if res: + return res + msg = (plan.get("message") or "").strip() + if msg: + return msg[:32] + return "结束" + + +def find_missing_plans( + conn: sqlite3.Connection, + *, + plan_id: int | None = None, + since: str | None = None, +) -> list[dict]: + sql = """ + SELECT p.* + FROM trend_pullback_plans p + WHERE TRIM(COALESCE(p.status, '')) != 'active' + AND NOT EXISTS ( + SELECT 1 FROM strategy_trade_snapshots s + WHERE s.strategy_type = ? AND s.source_id = p.id + ) + """ + params: list[object] = [STRATEGY_TREND] + if plan_id is not None: + sql += " AND p.id = ?" + params.append(int(plan_id)) + if since: + sql += " AND COALESCE(p.opened_at, '') >= ?" + params.append(since.strip()) + sql += " ORDER BY p.id ASC" + rows = conn.execute(sql, params).fetchall() + return [_row_dict(r) for r in rows] + + +def fetch_trade_for_plan(conn: sqlite3.Connection, plan_id: int) -> dict | None: + row = conn.execute( + """ + SELECT * FROM trade_records + WHERE trend_plan_id = ? + ORDER BY COALESCE(closed_at_ms, 0) DESC, id DESC + LIMIT 1 + """, + (int(plan_id),), + ).fetchone() + return _row_dict(row) if row else None + + +def backfill_one(conn: sqlite3.Connection, plan: dict, *, dry_run: bool) -> dict: + plan_id = int(plan["id"]) + trade = fetch_trade_for_plan(conn, plan_id) + result_label = resolve_result_label(plan, trade) + pnl_amount = None + closed_at = None + exit_price = None + entry = plan.get("avg_entry_price") or plan.get("live_price_ref") + margin = plan.get("plan_margin_capital") + leverage = plan.get("leverage") + + if trade: + pnl_amount = trade.get("pnl_amount") + closed_at = trade.get("closed_at") + entry = trade.get("trigger_price") or entry + margin = trade.get("margin_capital") or margin + leverage = trade.get("leverage") or leverage + exit_price = infer_exit_price( + plan.get("direction") or trade.get("direction") or "long", + entry, + margin, + leverage, + pnl_amount, + ) + + info = { + "plan_id": plan_id, + "symbol": plan.get("symbol"), + "status": plan.get("status"), + "result_label": result_label, + "closed_at": closed_at, + "pnl_amount": pnl_amount, + "exit_price": exit_price, + "legs_done": plan.get("legs_done"), + "dca_legs": plan.get("dca_legs"), + "has_trade": bool(trade), + } + + if dry_run: + return info + + save_trend_plan_snapshot( + {}, + conn, + plan, + result_label=result_label, + exit_price=exit_price, + pnl_amount=float(pnl_amount) if pnl_amount is not None else None, + closed_at=closed_at, + ) + return info + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Backfill missing trend_pullback strategy_trade_snapshots rows." + ) + parser.add_argument("--db", required=True, help="Path to instance sqlite db") + parser.add_argument("--plan-id", type=int, help="Only backfill this trend plan id") + parser.add_argument( + "--since", + help="Only plans with opened_at >= YYYY-MM-DD (optional)", + ) + parser.add_argument("--dry-run", action="store_true", help="Preview only (default)") + parser.add_argument("--apply", action="store_true", help="Write snapshots") + args = parser.parse_args() + if not args.dry_run and not args.apply: + args.dry_run = True + + db_path = Path(args.db).expanduser().resolve() + if not db_path.is_file(): + print(f"[ERR] DB not found: {db_path}") + return 1 + + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + init_strategy_snapshot_table(conn) + + missing = find_missing_plans( + conn, plan_id=args.plan_id, since=args.since + ) + if not missing: + print("[INFO] No closed trend plans missing strategy snapshots.") + conn.close() + return 0 + + print(f"[INFO] Found {len(missing)} plan(s) without strategy snapshot.") + applied = 0 + for plan in missing: + info = backfill_one(conn, plan, dry_run=not args.apply) + trade_hint = "有交易记录" if info["has_trade"] else "无交易记录" + print( + f" - plan #{info['plan_id']} {info['symbol']} " + f"status={info['status']} → {info['result_label']} " + f"closed={info['closed_at'] or '—'} pnl={info['pnl_amount']} " + f"补仓 {info['legs_done']}/{info['dca_legs']} ({trade_hint})" + ) + applied += 1 + + if args.apply: + conn.commit() + print(f"[OK] Backfilled {applied} snapshot(s).") + else: + print("[DRY-RUN] No changes written. Re-run with --apply to commit.") + + conn.close() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/backfill_trend_trade_records.py b/scripts/backfill_trend_trade_records.py index 306a5c9..a08e4a1 100644 --- a/scripts/backfill_trend_trade_records.py +++ b/scripts/backfill_trend_trade_records.py @@ -1,188 +1,188 @@ -#!/usr/bin/env python3 -"""补录缺失的趋势回调 trade_records(策略快照已有、交易记录漏写)。 - -典型原因:gate_bot insert_trade_record 曾不接受 entry_reason,_finalize_plan 写快照后插入失败。 - -用法: - python scripts/backfill_trend_trade_records.py --db crypto_monitor_gate_bot/crypto.db --dry-run - python scripts/backfill_trend_trade_records.py --db crypto_monitor_gate_bot/crypto.db --apply -""" -from __future__ import annotations - -import argparse -import json -import sqlite3 -import sys -from pathlib import Path - -_REPO_ROOT = Path(__file__).resolve().parents[1] -if str(_REPO_ROOT) not in sys.path: - sys.path.insert(0, str(_REPO_ROOT)) - -from strategy_snapshot_lib import STRATEGY_TREND # noqa: E402 -from strategy_trade_labels import ENTRY_REASON_TREND_PULLBACK, MONITOR_TYPE_TREND_PULLBACK # noqa: E402 - -STATUS_TO_RESULT = { - "stopped_sl": "止损", - "stopped_tp": "止盈", - "stopped_manual": "手动平仓", -} - - -def _row_dict(row) -> dict: - if row is None: - return {} - try: - return dict(row) - except Exception: - return {} - - -def _hold_minutes(hold_seconds: int) -> int: - try: - return max(0, int(round(float(hold_seconds) / 60.0))) - except (TypeError, ValueError): - return 0 - - -def backfill_one(conn: sqlite3.Connection, snap: dict, *, apply: bool) -> dict: - plan_id = int(snap.get("source_id") or 0) - if plan_id <= 0: - return {"plan_id": plan_id, "skipped": True, "reason": "invalid source_id"} - exists = conn.execute( - "SELECT id FROM trade_records WHERE trend_plan_id=? LIMIT 1", (plan_id,) - ).fetchone() - if exists: - return {"plan_id": plan_id, "skipped": True, "reason": "trade_exists"} - - try: - payload = json.loads(snap.get("snapshot_json") or "{}") - except Exception: - payload = {} - - plan = conn.execute( - "SELECT * FROM trend_pullback_plans WHERE id=?", (plan_id,) - ).fetchone() - plan_d = _row_dict(plan) - - symbol = snap.get("symbol") or plan_d.get("symbol") or payload.get("symbol") - direction = snap.get("direction") or plan_d.get("direction") or payload.get("direction") or "long" - result = (snap.get("result_label") or "").strip() or STATUS_TO_RESULT.get( - plan_d.get("status") or "", "手动平仓" - ) - opened_at = snap.get("opened_at") or plan_d.get("opened_at") - closed_at = snap.get("closed_at") - pnl_amount = snap.get("pnl_amount") - if pnl_amount is None: - pnl_amount = payload.get("pnl_amount") - - trigger_price = payload.get("avg_entry_price") or plan_d.get("avg_entry_price") - stop_loss = payload.get("stop_loss") or plan_d.get("stop_loss") - take_profit = payload.get("take_profit") or plan_d.get("take_profit") - margin_capital = payload.get("plan_margin_capital") or plan_d.get("plan_margin_capital") - leverage = payload.get("leverage") or plan_d.get("leverage") - - opened_ms = plan_d.get("opened_at_ms") - closed_ms = None - - hold_seconds = 0 - if opened_at and closed_at: - try: - from datetime import datetime - - fmt = "%Y-%m-%d %H:%M:%S" - o = datetime.strptime(str(opened_at).strip()[:19], fmt) - c = datetime.strptime(str(closed_at).strip()[:19], fmt) - hold_seconds = max(0, int((c - o).total_seconds())) - except Exception: - hold_seconds = 0 - - row = { - "symbol": symbol, - "monitor_type": MONITOR_TYPE_TREND_PULLBACK, - "direction": direction, - "trigger_price": trigger_price, - "stop_loss": stop_loss, - "initial_stop_loss": plan_d.get("initial_stop_loss") or stop_loss, - "take_profit": take_profit, - "margin_capital": margin_capital, - "leverage": leverage, - "pnl_amount": pnl_amount, - "hold_seconds": hold_seconds, - "trade_style": "trend_pullback", - "result": result, - "opened_at": opened_at, - "opened_at_ms": opened_ms, - "closed_at": closed_at, - "closed_at_ms": closed_ms, - "entry_reason": ENTRY_REASON_TREND_PULLBACK, - "trend_plan_id": plan_id, - } - - if not apply: - return {"plan_id": plan_id, "dry_run": True, "row": row} - - conn.execute( - """INSERT INTO trade_records ( - symbol, monitor_type, direction, trigger_price, stop_loss, initial_stop_loss, - take_profit, margin_capital, leverage, pnl_amount, hold_seconds, trade_style, - hold_minutes, opened_at, opened_at_ms, closed_at, closed_at_ms, result, - entry_reason, trend_plan_id - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", - ( - row["symbol"], - row["monitor_type"], - row["direction"], - row["trigger_price"], - row["stop_loss"], - row["initial_stop_loss"], - row["take_profit"], - row["margin_capital"], - row["leverage"], - row["pnl_amount"], - row["hold_seconds"], - row["trade_style"], - _hold_minutes(hold_seconds), - row["opened_at"], - row["opened_at_ms"], - row["closed_at"], - row["closed_at_ms"], - row["result"], - row["entry_reason"], - row["trend_plan_id"], - ), - ) - return {"plan_id": plan_id, "inserted": True} - - -def main() -> int: - ap = argparse.ArgumentParser() - ap.add_argument("--db", required=True, help="实例 sqlite 路径") - ap.add_argument("--apply", action="store_true", help="写入数据库(默认 dry-run)") - args = ap.parse_args() - db_path = Path(args.db) - if not db_path.is_file(): - print(f"数据库不存在: {db_path}") - return 1 - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - snaps = conn.execute( - """SELECT * FROM strategy_trade_snapshots - WHERE strategy_type=? ORDER BY id DESC""", - (STRATEGY_TREND,), - ).fetchall() - out = [] - for s in snaps: - r = backfill_one(conn, _row_dict(s), apply=args.apply) - out.append(r) - print(r) - if args.apply: - conn.commit() - conn.close() - inserted = sum(1 for x in out if x.get("inserted")) - print(f"done: inserted={inserted} total_snapshots={len(snaps)} apply={args.apply}") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) +#!/usr/bin/env python3 +"""补录缺失的趋势回调 trade_records(策略快照已有、交易记录漏写)。 + +典型原因:gate_bot insert_trade_record 曾不接受 entry_reason,_finalize_plan 写快照后插入失败。 + +用法: + python scripts/backfill_trend_trade_records.py --db crypto_monitor_gate_bot/crypto.db --dry-run + python scripts/backfill_trend_trade_records.py --db crypto_monitor_gate_bot/crypto.db --apply +""" +from __future__ import annotations + +import argparse +import json +import sqlite3 +import sys +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from lib.strategy.strategy_snapshot_lib import STRATEGY_TREND # noqa: E402 +from lib.strategy.strategy_trade_labels import ENTRY_REASON_TREND_PULLBACK, MONITOR_TYPE_TREND_PULLBACK # noqa: E402 + +STATUS_TO_RESULT = { + "stopped_sl": "止损", + "stopped_tp": "止盈", + "stopped_manual": "手动平仓", +} + + +def _row_dict(row) -> dict: + if row is None: + return {} + try: + return dict(row) + except Exception: + return {} + + +def _hold_minutes(hold_seconds: int) -> int: + try: + return max(0, int(round(float(hold_seconds) / 60.0))) + except (TypeError, ValueError): + return 0 + + +def backfill_one(conn: sqlite3.Connection, snap: dict, *, apply: bool) -> dict: + plan_id = int(snap.get("source_id") or 0) + if plan_id <= 0: + return {"plan_id": plan_id, "skipped": True, "reason": "invalid source_id"} + exists = conn.execute( + "SELECT id FROM trade_records WHERE trend_plan_id=? LIMIT 1", (plan_id,) + ).fetchone() + if exists: + return {"plan_id": plan_id, "skipped": True, "reason": "trade_exists"} + + try: + payload = json.loads(snap.get("snapshot_json") or "{}") + except Exception: + payload = {} + + plan = conn.execute( + "SELECT * FROM trend_pullback_plans WHERE id=?", (plan_id,) + ).fetchone() + plan_d = _row_dict(plan) + + symbol = snap.get("symbol") or plan_d.get("symbol") or payload.get("symbol") + direction = snap.get("direction") or plan_d.get("direction") or payload.get("direction") or "long" + result = (snap.get("result_label") or "").strip() or STATUS_TO_RESULT.get( + plan_d.get("status") or "", "手动平仓" + ) + opened_at = snap.get("opened_at") or plan_d.get("opened_at") + closed_at = snap.get("closed_at") + pnl_amount = snap.get("pnl_amount") + if pnl_amount is None: + pnl_amount = payload.get("pnl_amount") + + trigger_price = payload.get("avg_entry_price") or plan_d.get("avg_entry_price") + stop_loss = payload.get("stop_loss") or plan_d.get("stop_loss") + take_profit = payload.get("take_profit") or plan_d.get("take_profit") + margin_capital = payload.get("plan_margin_capital") or plan_d.get("plan_margin_capital") + leverage = payload.get("leverage") or plan_d.get("leverage") + + opened_ms = plan_d.get("opened_at_ms") + closed_ms = None + + hold_seconds = 0 + if opened_at and closed_at: + try: + from datetime import datetime + + fmt = "%Y-%m-%d %H:%M:%S" + o = datetime.strptime(str(opened_at).strip()[:19], fmt) + c = datetime.strptime(str(closed_at).strip()[:19], fmt) + hold_seconds = max(0, int((c - o).total_seconds())) + except Exception: + hold_seconds = 0 + + row = { + "symbol": symbol, + "monitor_type": MONITOR_TYPE_TREND_PULLBACK, + "direction": direction, + "trigger_price": trigger_price, + "stop_loss": stop_loss, + "initial_stop_loss": plan_d.get("initial_stop_loss") or stop_loss, + "take_profit": take_profit, + "margin_capital": margin_capital, + "leverage": leverage, + "pnl_amount": pnl_amount, + "hold_seconds": hold_seconds, + "trade_style": "trend_pullback", + "result": result, + "opened_at": opened_at, + "opened_at_ms": opened_ms, + "closed_at": closed_at, + "closed_at_ms": closed_ms, + "entry_reason": ENTRY_REASON_TREND_PULLBACK, + "trend_plan_id": plan_id, + } + + if not apply: + return {"plan_id": plan_id, "dry_run": True, "row": row} + + conn.execute( + """INSERT INTO trade_records ( + symbol, monitor_type, direction, trigger_price, stop_loss, initial_stop_loss, + take_profit, margin_capital, leverage, pnl_amount, hold_seconds, trade_style, + hold_minutes, opened_at, opened_at_ms, closed_at, closed_at_ms, result, + entry_reason, trend_plan_id + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + ( + row["symbol"], + row["monitor_type"], + row["direction"], + row["trigger_price"], + row["stop_loss"], + row["initial_stop_loss"], + row["take_profit"], + row["margin_capital"], + row["leverage"], + row["pnl_amount"], + row["hold_seconds"], + row["trade_style"], + _hold_minutes(hold_seconds), + row["opened_at"], + row["opened_at_ms"], + row["closed_at"], + row["closed_at_ms"], + row["result"], + row["entry_reason"], + row["trend_plan_id"], + ), + ) + return {"plan_id": plan_id, "inserted": True} + + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--db", required=True, help="实例 sqlite 路径") + ap.add_argument("--apply", action="store_true", help="写入数据库(默认 dry-run)") + args = ap.parse_args() + db_path = Path(args.db) + if not db_path.is_file(): + print(f"数据库不存在: {db_path}") + return 1 + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + snaps = conn.execute( + """SELECT * FROM strategy_trade_snapshots + WHERE strategy_type=? ORDER BY id DESC""", + (STRATEGY_TREND,), + ).fetchall() + out = [] + for s in snaps: + r = backfill_one(conn, _row_dict(s), apply=args.apply) + out.append(r) + print(r) + if args.apply: + conn.commit() + conn.close() + inserted = sum(1 for x in out if x.get("inserted")) + print(f"done: inserted={inserted} total_snapshots={len(snaps)} apply={args.apply}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/build_embed_fragment.py b/scripts/build_embed_fragment.py index 8baec0e..6ce0833 100644 --- a/scripts/build_embed_fragment.py +++ b/scripts/build_embed_fragment.py @@ -1,33 +1,33 @@ -"""Build embed_page_fragment.html from gate index.html.""" -from __future__ import annotations - -from pathlib import Path - -ROOT = Path(__file__).resolve().parent.parent -src_lines = (ROOT / "crypto_monitor_gate" / "templates" / "index.html").read_text( - encoding="utf-8" -).splitlines() - -# 1-based line numbers from index.html -macro_body = src_lines[243:262] # {% macro %} … {% endmacro %} -grid_inner = src_lines[328:736] # inside .grid (exclude outer wrapper) -stats_block = src_lines[738:772] - -out_lines = [ - "{# Hub iframe tab fragment — shared via embed_templates #}", - *macro_body, - '
', - *grid_inner, - "
", - *stats_block, -] - -out_dir = ROOT / "embed_templates" -out_dir.mkdir(exist_ok=True) -text = "\n".join(out_lines) + "\n" -text = text.replace( - "{% include 'order_monitor_rule_tips_gate.html' %}", - "{% include order_rule_tips_tpl %}", -) -(out_dir / "embed_page_fragment.html").write_text(text, encoding="utf-8") -print("wrote", out_dir / "embed_page_fragment.html", "lines", len(out_lines)) +"""Build embed_page_fragment.html from gate index.html.""" +from __future__ import annotations + +from pathlib import Path + +ROOT = Path(__file__).resolve().parent.parent +src_lines = (ROOT / "crypto_monitor_gate" / "templates" / "index.html").read_text( + encoding="utf-8" +).splitlines() + +# 1-based line numbers from index.html +macro_body = src_lines[243:262] # {% macro %} … {% endmacro %} +grid_inner = src_lines[328:736] # inside .grid (exclude outer wrapper) +stats_block = src_lines[738:772] + +out_lines = [ + "{# Hub iframe tab fragment — shared via embed_templates #}", + *macro_body, + '
', + *grid_inner, + "
", + *stats_block, +] + +out_dir = ROOT / "lib" / "instance" / "templates" +out_dir.mkdir(exist_ok=True) +text = "\n".join(out_lines) + "\n" +text = text.replace( + "{% include 'order_monitor_rule_tips_gate.html' %}", + "{% include order_rule_tips_tpl %}", +) +(out_dir / "embed_page_fragment.html").write_text(text, encoding="utf-8") +print("wrote", out_dir / "embed_page_fragment.html", "lines", len(out_lines)) diff --git a/scripts/clear_hub_kline_db.py b/scripts/clear_hub_kline_db.py index 9ae5187..71b8704 100644 --- a/scripts/clear_hub_kline_db.py +++ b/scripts/clear_hub_kline_db.py @@ -1,93 +1,93 @@ -#!/usr/bin/env python3 -"""清空中控 K 线 SQLite 缓存(hub_kline.db),便于清库后全量重拉。 - -用法(Linux 云服务器,在仓库根目录): - python3 scripts/clear_hub_kline_db.py --dry-run - python3 scripts/clear_hub_kline_db.py --apply - python3 scripts/clear_hub_kline_db.py --apply --exchange binance --symbol BTC/USDT --timeframe 15m - -默认库路径:环境变量 HUB_KLINE_DB_PATH,或 manual_trading_hub/data/hub_kline.db -""" -from __future__ import annotations - -import argparse -import os -import sys -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(ROOT)) - -from hub_kline_store import ( # noqa: E402 - clear_all_bars, - clear_series_bars, - default_db_path, - init_db, -) - - -def main() -> int: - parser = argparse.ArgumentParser(description="Clear manual-trading-hub K-line SQLite cache.") - parser.add_argument( - "--db", - default=os.getenv("HUB_KLINE_DB_PATH", "").strip() or str(default_db_path()), - help="hub_kline.db path", - ) - parser.add_argument("--exchange", default="", help="exchange_key, e.g. binance") - parser.add_argument("--symbol", default="", help="symbol, e.g. BTC/USDT") - parser.add_argument("--timeframe", default="", help="optional timeframe, e.g. 15m") - parser.add_argument("--dry-run", action="store_true", help="count only") - parser.add_argument("--apply", action="store_true", help="execute delete") - args = parser.parse_args() - - db_path = Path(args.db) - if not db_path.is_file(): - print(f"DB not found: {db_path}", file=sys.stderr) - return 1 - - init_db(db_path) - ex = (args.exchange or "").strip().lower() - sym = (args.symbol or "").strip().upper() - tf = (args.timeframe or "").strip().lower() or None - - if args.dry_run and not args.apply: - import sqlite3 - - conn = sqlite3.connect(str(db_path)) - try: - if ex and sym: - if tf: - n = conn.execute( - "SELECT COUNT(*) FROM ohlcv_bars WHERE exchange_key=? AND symbol=? AND timeframe=?", - (ex, sym, tf), - ).fetchone()[0] - print(f"would delete series rows: {n} ({ex} {sym} {tf})") - else: - n = conn.execute( - "SELECT COUNT(*) FROM ohlcv_bars WHERE exchange_key=? AND symbol=?", - (ex, sym), - ).fetchone()[0] - print(f"would delete symbol rows: {n} ({ex} {sym} all tf)") - else: - n = conn.execute("SELECT COUNT(*) FROM ohlcv_bars").fetchone()[0] - print(f"would delete all ohlcv_bars rows: {n}") - finally: - conn.close() - return 0 - - if not args.apply: - print("Specify --apply to delete (or --dry-run to preview).", file=sys.stderr) - return 1 - - if ex and sym: - removed = clear_series_bars(ex, sym, tf, db_path) - scope = f"{ex} {sym}" + (f" {tf}" if tf else " (all timeframes)") - print(f"cleared {removed} rows for {scope}") - else: - removed = clear_all_bars(db_path) - print(f"cleared all {removed} ohlcv_bars rows from {db_path}") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) +#!/usr/bin/env python3 +"""清空中控 K 线 SQLite 缓存(hub_kline.db),便于清库后全量重拉。 + +用法(Linux 云服务器,在仓库根目录): + python3 scripts/clear_hub_kline_db.py --dry-run + python3 scripts/clear_hub_kline_db.py --apply + python3 scripts/clear_hub_kline_db.py --apply --exchange binance --symbol BTC/USDT --timeframe 15m + +默认库路径:环境变量 HUB_KLINE_DB_PATH,或 manual_trading_hub/data/hub_kline.db +""" +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + +from lib.hub.hub_kline_store import ( # noqa: E402 + clear_all_bars, + clear_series_bars, + default_db_path, + init_db, +) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Clear manual-trading-hub K-line SQLite cache.") + parser.add_argument( + "--db", + default=os.getenv("HUB_KLINE_DB_PATH", "").strip() or str(default_db_path()), + help="hub_kline.db path", + ) + parser.add_argument("--exchange", default="", help="exchange_key, e.g. binance") + parser.add_argument("--symbol", default="", help="symbol, e.g. BTC/USDT") + parser.add_argument("--timeframe", default="", help="optional timeframe, e.g. 15m") + parser.add_argument("--dry-run", action="store_true", help="count only") + parser.add_argument("--apply", action="store_true", help="execute delete") + args = parser.parse_args() + + db_path = Path(args.db) + if not db_path.is_file(): + print(f"DB not found: {db_path}", file=sys.stderr) + return 1 + + init_db(db_path) + ex = (args.exchange or "").strip().lower() + sym = (args.symbol or "").strip().upper() + tf = (args.timeframe or "").strip().lower() or None + + if args.dry_run and not args.apply: + import sqlite3 + + conn = sqlite3.connect(str(db_path)) + try: + if ex and sym: + if tf: + n = conn.execute( + "SELECT COUNT(*) FROM ohlcv_bars WHERE exchange_key=? AND symbol=? AND timeframe=?", + (ex, sym, tf), + ).fetchone()[0] + print(f"would delete series rows: {n} ({ex} {sym} {tf})") + else: + n = conn.execute( + "SELECT COUNT(*) FROM ohlcv_bars WHERE exchange_key=? AND symbol=?", + (ex, sym), + ).fetchone()[0] + print(f"would delete symbol rows: {n} ({ex} {sym} all tf)") + else: + n = conn.execute("SELECT COUNT(*) FROM ohlcv_bars").fetchone()[0] + print(f"would delete all ohlcv_bars rows: {n}") + finally: + conn.close() + return 0 + + if not args.apply: + print("Specify --apply to delete (or --dry-run to preview).", file=sys.stderr) + return 1 + + if ex and sym: + removed = clear_series_bars(ex, sym, tf, db_path) + scope = f"{ex} {sym}" + (f" {tf}" if tf else " (all timeframes)") + print(f"cleared {removed} rows for {scope}") + else: + removed = clear_all_bars(db_path) + print(f"cleared all {removed} ohlcv_bars rows from {db_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/dedupe_strategy_snapshots.py b/scripts/dedupe_strategy_snapshots.py index bff8690..7a2a36e 100644 --- a/scripts/dedupe_strategy_snapshots.py +++ b/scripts/dedupe_strategy_snapshots.py @@ -1,67 +1,67 @@ -#!/usr/bin/env python3 -"""清理 strategy_trade_snapshots 重复行(同计划 + 同结果仅保留 id 最大的一条)。 - -用法(在实例目录,如 crypto_monitor_gate_bot): - python ../scripts/dedupe_strategy_snapshots.py - python ../scripts/dedupe_strategy_snapshots.py --db crypto.db -""" -from __future__ import annotations - -import argparse -import os -import sqlite3 -import sys -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(ROOT)) - -from strategy_snapshot_lib import dedupe_strategy_snapshots, init_strategy_snapshot_table # noqa: E402 - - -def main() -> int: - parser = argparse.ArgumentParser(description="Dedupe strategy_trade_snapshots rows.") - parser.add_argument( - "--db", - default=os.getenv("DB_PATH", "crypto.db"), - help="SQLite database path (default: DB_PATH or crypto.db)", - ) - parser.add_argument("--dry-run", action="store_true", help="Count only, do not delete") - args = parser.parse_args() - - db_path = Path(args.db) - if not db_path.is_file(): - print(f"DB not found: {db_path}", file=sys.stderr) - return 1 - - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - init_strategy_snapshot_table(conn) - before = conn.execute("SELECT COUNT(*) AS c FROM strategy_trade_snapshots").fetchone()["c"] - dup_groups = conn.execute( - """SELECT strategy_type, source_id, result_label, COUNT(*) AS n - FROM strategy_trade_snapshots - GROUP BY strategy_type, source_id, result_label - HAVING n > 1 - ORDER BY n DESC""" - ).fetchall() - extra = sum(int(r["n"]) - 1 for r in dup_groups) - print(f"snapshots total={before}, duplicate rows to remove={extra}, groups={len(dup_groups)}") - for r in dup_groups[:20]: - print( - f" {r['strategy_type']} plan={r['source_id']} " - f"{r['result_label']} x{r['n']}" - ) - if args.dry_run: - conn.close() - return 0 - removed = dedupe_strategy_snapshots(conn) - conn.commit() - after = conn.execute("SELECT COUNT(*) AS c FROM strategy_trade_snapshots").fetchone()["c"] - conn.close() - print(f"removed={removed}, remaining={after}") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) +#!/usr/bin/env python3 +"""清理 strategy_trade_snapshots 重复行(同计划 + 同结果仅保留 id 最大的一条)。 + +用法(在实例目录,如 crypto_monitor_gate_bot): + python ../scripts/dedupe_strategy_snapshots.py + python ../scripts/dedupe_strategy_snapshots.py --db crypto.db +""" +from __future__ import annotations + +import argparse +import os +import sqlite3 +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + +from lib.strategy.strategy_snapshot_lib import dedupe_strategy_snapshots, init_strategy_snapshot_table # noqa: E402 + + +def main() -> int: + parser = argparse.ArgumentParser(description="Dedupe strategy_trade_snapshots rows.") + parser.add_argument( + "--db", + default=os.getenv("DB_PATH", "crypto.db"), + help="SQLite database path (default: DB_PATH or crypto.db)", + ) + parser.add_argument("--dry-run", action="store_true", help="Count only, do not delete") + args = parser.parse_args() + + db_path = Path(args.db) + if not db_path.is_file(): + print(f"DB not found: {db_path}", file=sys.stderr) + return 1 + + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + init_strategy_snapshot_table(conn) + before = conn.execute("SELECT COUNT(*) AS c FROM strategy_trade_snapshots").fetchone()["c"] + dup_groups = conn.execute( + """SELECT strategy_type, source_id, result_label, COUNT(*) AS n + FROM strategy_trade_snapshots + GROUP BY strategy_type, source_id, result_label + HAVING n > 1 + ORDER BY n DESC""" + ).fetchall() + extra = sum(int(r["n"]) - 1 for r in dup_groups) + print(f"snapshots total={before}, duplicate rows to remove={extra}, groups={len(dup_groups)}") + for r in dup_groups[:20]: + print( + f" {r['strategy_type']} plan={r['source_id']} " + f"{r['result_label']} x{r['n']}" + ) + if args.dry_run: + conn.close() + return 0 + removed = dedupe_strategy_snapshots(conn) + conn.commit() + after = conn.execute("SELECT COUNT(*) AS c FROM strategy_trade_snapshots").fetchone()["c"] + conn.close() + print(f"removed={removed}, remaining={after}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/extract_instance_page_assets.py b/scripts/extract_instance_page_assets.py index 7902db6..caff54e 100644 --- a/scripts/extract_instance_page_assets.py +++ b/scripts/extract_instance_page_assets.py @@ -1,49 +1,49 @@ -"""One-off: extract instance_page.css / instance_page_boot.js from gate index.html.""" -from __future__ import annotations - -import re -from pathlib import Path - -ROOT = Path(__file__).resolve().parent.parent -src = ROOT / "crypto_monitor_gate" / "templates" / "index.html" -text = src.read_text(encoding="utf-8") - -m = re.search(r"", text, re.S) -if m: - (ROOT / "static" / "instance_page.css").write_text(m.group(1).strip() + "\n", encoding="utf-8") - -marker = '' -if marker in text: - part = text.split(marker, 1)[1] - m2 = re.search(r"\s*", part, re.S) - if m2: - boot = m2.group(1).strip() - boot = boot.replace( - "setInterval(refreshAccountSnapshot, {{ balance_refresh_seconds * 1000 }});", - "setInterval(refreshAccountSnapshot, Number(document.body.dataset.balanceRefreshMs || 30000));", - ) - boot = boot.replace( - "setInterval(refreshPriceSnapshotConditional, {{ price_refresh_seconds * 1000 }});", - "setInterval(refreshPriceSnapshotConditional, Number(document.body.dataset.priceRefreshMs || 5000));", - ) - (ROOT / "static" / "instance_page_boot.js").write_text(boot + "\n", encoding="utf-8") - - part2 = text.split(marker, 1)[1] - m3 = re.search(r"\s*", part2, re.S) - if m3: - boot_tpl = m3.group(1).strip() - boot_tpl = boot_tpl.replace( - "setInterval(refreshAccountSnapshot, {{ balance_refresh_seconds * 1000 }});", - "setInterval(refreshAccountSnapshot, Number(document.body.dataset.balanceRefreshMs || 30000));", - ) - boot_tpl = boot_tpl.replace( - "setInterval(refreshPriceSnapshotConditional, {{ price_refresh_seconds * 1000 }});", - "setInterval(refreshPriceSnapshotConditional, Number(document.body.dataset.priceRefreshMs || 5000));", - ) - embed_dir = ROOT / "embed_templates" - embed_dir.mkdir(exist_ok=True) - (embed_dir / "embed_boot_scripts.html").write_text( - "\n", encoding="utf-8" - ) - -print("done") +"""One-off: extract instance_page.css / instance_page_boot.js from gate index.html.""" +from __future__ import annotations + +import re +from pathlib import Path + +ROOT = Path(__file__).resolve().parent.parent +src = ROOT / "crypto_monitor_gate" / "templates" / "index.html" +text = src.read_text(encoding="utf-8") + +m = re.search(r"", text, re.S) +if m: + (ROOT / "lib" / "common" / "static" / "instance_page.css").write_text(m.group(1).strip() + "\n", encoding="utf-8") + +marker = '' +if marker in text: + part = text.split(marker, 1)[1] + m2 = re.search(r"\s*", part, re.S) + if m2: + boot = m2.group(1).strip() + boot = boot.replace( + "setInterval(refreshAccountSnapshot, {{ balance_refresh_seconds * 1000 }});", + "setInterval(refreshAccountSnapshot, Number(document.body.dataset.balanceRefreshMs || 30000));", + ) + boot = boot.replace( + "setInterval(refreshPriceSnapshotConditional, {{ price_refresh_seconds * 1000 }});", + "setInterval(refreshPriceSnapshotConditional, Number(document.body.dataset.priceRefreshMs || 5000));", + ) + (ROOT / "lib" / "common" / "static" / "instance_page_boot.js").write_text(boot + "\n", encoding="utf-8") + + part2 = text.split(marker, 1)[1] + m3 = re.search(r"\s*", part2, re.S) + if m3: + boot_tpl = m3.group(1).strip() + boot_tpl = boot_tpl.replace( + "setInterval(refreshAccountSnapshot, {{ balance_refresh_seconds * 1000 }});", + "setInterval(refreshAccountSnapshot, Number(document.body.dataset.balanceRefreshMs || 30000));", + ) + boot_tpl = boot_tpl.replace( + "setInterval(refreshPriceSnapshotConditional, {{ price_refresh_seconds * 1000 }});", + "setInterval(refreshPriceSnapshotConditional, Number(document.body.dataset.priceRefreshMs || 5000));", + ) + embed_dir = ROOT / "lib" / "instance" / "templates" + embed_dir.mkdir(exist_ok=True) + (embed_dir / "embed_boot_scripts.html").write_text( + "\n", encoding="utf-8" + ) + +print("done") diff --git a/scripts/fix_trend_handoff_monitor_type.py b/scripts/fix_trend_handoff_monitor_type.py index f6be631..8983a72 100644 --- a/scripts/fix_trend_handoff_monitor_type.py +++ b/scripts/fix_trend_handoff_monitor_type.py @@ -1,78 +1,78 @@ -#!/usr/bin/env python3 -"""修正趋势保本移交后 monitor_type 仍为「下单监控」的历史数据。""" -from __future__ import annotations - -import argparse -import sqlite3 -from pathlib import Path - -from strategy_trade_labels import MONITOR_TYPE_TREND_PULLBACK - - -def main() -> int: - parser = argparse.ArgumentParser(description="Fix trend handoff order/trade monitor_type labels.") - parser.add_argument("--db", required=True, help="Path to instance sqlite db") - parser.add_argument("--dry-run", action="store_true", help="Preview only") - parser.add_argument("--apply", action="store_true", help="Apply updates") - args = parser.parse_args() - if not args.dry_run and not args.apply: - args.dry_run = True - - db_path = Path(args.db).expanduser().resolve() - if not db_path.is_file(): - print(f"[ERR] DB not found: {db_path}") - return 1 - - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - cur = conn.cursor() - - cur.execute( - """ - SELECT COUNT(*) AS c FROM order_monitors - WHERE trend_plan_id IS NOT NULL AND trend_plan_id > 0 - AND (monitor_type IS NULL OR TRIM(monitor_type) = '' OR monitor_type = '下单监控') - """ - ) - om_n = int(cur.fetchone()["c"]) - cur.execute( - """ - SELECT COUNT(*) AS c FROM trade_records - WHERE trend_plan_id IS NOT NULL AND trend_plan_id > 0 - AND (monitor_type IS NULL OR TRIM(monitor_type) = '' OR monitor_type = '下单监控') - """ - ) - tr_n = int(cur.fetchone()["c"]) - print(f"[INFO] order_monitors to fix: {om_n}") - print(f"[INFO] trade_records to fix: {tr_n}") - - if args.dry_run: - conn.close() - return 0 - - cur.execute( - """ - UPDATE order_monitors - SET monitor_type=? - WHERE trend_plan_id IS NOT NULL AND trend_plan_id > 0 - AND (monitor_type IS NULL OR TRIM(monitor_type) = '' OR monitor_type = '下单监控') - """, - (MONITOR_TYPE_TREND_PULLBACK,), - ) - cur.execute( - """ - UPDATE trade_records - SET monitor_type=? - WHERE trend_plan_id IS NOT NULL AND trend_plan_id > 0 - AND (monitor_type IS NULL OR TRIM(monitor_type) = '' OR monitor_type = '下单监控') - """, - (MONITOR_TYPE_TREND_PULLBACK,), - ) - conn.commit() - conn.close() - print("[OK] Applied.") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) +#!/usr/bin/env python3 +"""修正趋势保本移交后 monitor_type 仍为「下单监控」的历史数据。""" +from __future__ import annotations + +import argparse +import sqlite3 +from pathlib import Path + +from lib.strategy.strategy_trade_labels import MONITOR_TYPE_TREND_PULLBACK + + +def main() -> int: + parser = argparse.ArgumentParser(description="Fix trend handoff order/trade monitor_type labels.") + parser.add_argument("--db", required=True, help="Path to instance sqlite db") + parser.add_argument("--dry-run", action="store_true", help="Preview only") + parser.add_argument("--apply", action="store_true", help="Apply updates") + args = parser.parse_args() + if not args.dry_run and not args.apply: + args.dry_run = True + + db_path = Path(args.db).expanduser().resolve() + if not db_path.is_file(): + print(f"[ERR] DB not found: {db_path}") + return 1 + + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + cur = conn.cursor() + + cur.execute( + """ + SELECT COUNT(*) AS c FROM order_monitors + WHERE trend_plan_id IS NOT NULL AND trend_plan_id > 0 + AND (monitor_type IS NULL OR TRIM(monitor_type) = '' OR monitor_type = '下单监控') + """ + ) + om_n = int(cur.fetchone()["c"]) + cur.execute( + """ + SELECT COUNT(*) AS c FROM trade_records + WHERE trend_plan_id IS NOT NULL AND trend_plan_id > 0 + AND (monitor_type IS NULL OR TRIM(monitor_type) = '' OR monitor_type = '下单监控') + """ + ) + tr_n = int(cur.fetchone()["c"]) + print(f"[INFO] order_monitors to fix: {om_n}") + print(f"[INFO] trade_records to fix: {tr_n}") + + if args.dry_run: + conn.close() + return 0 + + cur.execute( + """ + UPDATE order_monitors + SET monitor_type=? + WHERE trend_plan_id IS NOT NULL AND trend_plan_id > 0 + AND (monitor_type IS NULL OR TRIM(monitor_type) = '' OR monitor_type = '下单监控') + """, + (MONITOR_TYPE_TREND_PULLBACK,), + ) + cur.execute( + """ + UPDATE trade_records + SET monitor_type=? + WHERE trend_plan_id IS NOT NULL AND trend_plan_id > 0 + AND (monitor_type IS NULL OR TRIM(monitor_type) = '' OR monitor_type = '下单监控') + """, + (MONITOR_TYPE_TREND_PULLBACK,), + ) + conn.commit() + conn.close() + print("[OK] Applied.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/migrate_to_lib.py b/scripts/migrate_to_lib.py new file mode 100644 index 0000000..d7b2ed3 --- /dev/null +++ b/scripts/migrate_to_lib.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 +"""One-shot: move root shared modules into lib/ and rewrite imports.""" +from __future__ import annotations + +import re +import subprocess +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parent.parent + +PACKAGE_FILES: dict[str, list[str]] = { + "strategy": [ + "strategy_config.py", + "strategy_db.py", + "strategy_exchange_base.py", + "strategy_exchange_binance.py", + "strategy_exchange_gate.py", + "strategy_exchange_okx.py", + "strategy_records_register.py", + "strategy_register.py", + "strategy_roll_lib.py", + "strategy_roll_monitor_lib.py", + "strategy_roll_ui_lib.py", + "strategy_snapshot_lib.py", + "strategy_trade_labels.py", + "strategy_trend_exchange.py", + "strategy_trend_lib.py", + "strategy_trend_register.py", + "strategy_ui.py", + "strategy_wechat_notify.py", + ], + "key_monitor": [ + "key_monitor_full_margin_lib.py", + "key_monitor_lib.py", + "key_monitor_schema_lib.py", + "key_sl_tp_lib.py", + "fib_key_monitor_lib.py", + "false_breakout_key_monitor_lib.py", + "trigger_entry_key_monitor_lib.py", + ], + "trade": [ + "trade_result_lib.py", + "trade_exchange_stats_lib.py", + "trade_stats_calendar_lib.py", + "order_monitor_display_lib.py", + "position_sizing_lib.py", + "account_risk_lib.py", + "manual_sltp_lib.py", + "time_close_lib.py", + "daily_open_limit_lib.py", + ], + "hub": [ + "hub_auth.py", + "hub_bridge.py", + "hub_calculator_lib.py", + "hub_calculator_market_lib.py", + "hub_entry_plan_lib.py", + "hub_fund_history_lib.py", + "hub_host_status_lib.py", + "hub_kline_store.py", + "hub_macro_calendar_lib.py", + "hub_market_info_lib.py", + "hub_ohlcv_lib.py", + "hub_position_metrics.py", + "hub_sso.py", + "hub_symbol_archive_lib.py", + "hub_trades_lib.py", + "hub_volume_rank_lib.py", + ], + "ai": [ + "ai_client.py", + "ai_review_lib.py", + ], + "instance": [ + "instance_embed_context_lib.py", + "instance_embed_lib.py", + "instance_nav_lib.py", + "focus_chart_lib.py", + "journal_chart_lib.py", + ], + "exchange": [ + "gate_transfer_lib.py", + "gate_position_history_lib.py", + "okx_orders_lib.py", + ], + "common": [ + "form_submit_lib.py", + "history_window_lib.py", + "wechat_notify_lib.py", + "auto_transfer_daily_lib.py", + ], +} + +DIR_MOVES: list[tuple[str, str]] = [ + ("strategy_templates", "lib/strategy/templates"), + ("embed_templates", "lib/instance/templates"), + ("static", "lib/common/static"), +] + +MODULE_TO_LIB: dict[str, str] = {} +for pkg, files in PACKAGE_FILES.items(): + for fname in files: + MODULE_TO_LIB[fname[:-3]] = f"lib.{pkg}.{fname[:-3]}" + +IMPORT_FROM_RE = re.compile( + r"^(\s*)from\s+(" + "|".join(re.escape(m) for m in sorted(MODULE_TO_LIB, key=len, reverse=True)) + r")\s+import\s+", + re.MULTILINE, +) +IMPORT_BARE_RE = re.compile( + r"^(\s*)import\s+(" + "|".join(re.escape(m) for m in sorted(MODULE_TO_LIB, key=len, reverse=True)) + r")(\s|$)", + re.MULTILINE, +) + + +def git_mv(src: Path, dst: Path) -> None: + dst.parent.mkdir(parents=True, exist_ok=True) + if not src.exists(): + if dst.exists(): + return + raise FileNotFoundError(src) + subprocess.run(["git", "mv", str(src), str(dst)], cwd=ROOT, check=True) + + +def move_files() -> None: + (ROOT / "lib").mkdir(exist_ok=True) + for pkg in PACKAGE_FILES: + (ROOT / "lib" / pkg).mkdir(parents=True, exist_ok=True) + init = ROOT / "lib" / pkg / "__init__.py" + if not init.exists(): + init.write_text('"""Shared library package."""\n', encoding="utf-8") + + lib_init = ROOT / "lib" / "__init__.py" + if not lib_init.exists(): + lib_init.write_text('"""crypto_monitor shared libraries."""\n', encoding="utf-8") + + paths_py = ROOT / "lib" / "paths.py" + if not paths_py.exists(): + paths_py.write_text( + '''"""Repository path helpers for lib/ assets.""" +from __future__ import annotations + +import os +from pathlib import Path + +LIB_DIR = Path(__file__).resolve().parent +REPO_ROOT = LIB_DIR.parent + + +def strategy_templates_dir(repo_root: str | Path | None = None) -> str: + root = Path(repo_root) if repo_root is not None else REPO_ROOT + return str(root / "lib" / "strategy" / "templates") + + +def embed_templates_dir(repo_root: str | Path | None = None) -> str: + root = Path(repo_root) if repo_root is not None else REPO_ROOT + return str(root / "lib" / "instance" / "templates") + + +def common_static_dir(repo_root: str | Path | None = None) -> str: + root = Path(repo_root) if repo_root is not None else REPO_ROOT + return str(root / "lib" / "common" / "static") +''', + encoding="utf-8", + ) + + for pkg, files in PACKAGE_FILES.items(): + for fname in files: + git_mv(ROOT / fname, ROOT / "lib" / pkg / fname) + + for src_rel, dst_rel in DIR_MOVES: + git_mv(ROOT / src_rel, ROOT / dst_rel) + + +def rewrite_imports_in_text(text: str) -> str: + def from_repl(m: re.Match) -> str: + mod = m.group(2) + return f"{m.group(1)}from {MODULE_TO_LIB[mod]} import " + + def bare_repl(m: re.Match) -> str: + mod = m.group(2) + return f"{m.group(1)}import {MODULE_TO_LIB[mod]}{m.group(3)}" + + text = IMPORT_FROM_RE.sub(from_repl, text) + text = IMPORT_BARE_RE.sub(bare_repl, text) + return text + + +def patch_path_literals(text: str) -> str: + replacements = [ + ('os.path.join(repo_root, "strategy_templates")', 'strategy_templates_dir(repo_root)'), + ('os.path.join(repo_root, "embed_templates")', 'embed_templates_dir(repo_root)'), + ('os.path.join(os.path.dirname(BASE_DIR), "static")', 'common_static_dir(os.path.dirname(BASE_DIR))'), + ('_REPO_ROOT / "static"', '_REPO_ROOT / "lib" / "common" / "static"'), + ('ROOT / "strategy_templates"', 'ROOT / "lib" / "strategy" / "templates"'), + ('ROOT / "embed_templates"', 'ROOT / "lib" / "instance" / "templates"'), + ('ROOT / "static"', 'ROOT / "lib" / "common" / "static"'), + ] + for old, new in replacements: + text = text.replace(old, new) + return text + + +def ensure_paths_import(text: str, filepath: Path) -> str: + needs = [] + if "strategy_templates_dir(" in text and "from lib.paths import" not in text: + needs.append("strategy_templates_dir") + if "embed_templates_dir(" in text and "from lib.paths import" not in text: + needs.append("embed_templates_dir") + if "common_static_dir(" in text and "from lib.paths import" not in text: + needs.append("common_static_dir") + if not needs: + return text + imp = f"from lib.paths import {', '.join(sorted(set(needs)))}\n" + if text.startswith('"""') or text.startswith("'''"): + end = text.find('"""', 3) if text.startswith('"""') else text.find("'''", 3) + if end != -1: + end += 3 + return text[:end] + "\n\n" + imp + text[end + 1 :] + if text.startswith("from __future__"): + lines = text.splitlines(keepends=True) + i = 0 + while i < len(lines) and ( + lines[i].startswith("from __future__") or lines[i].strip() == "" + ): + i += 1 + return "".join(lines[:i]) + imp + "".join(lines[i:]) + return imp + text + + +def rewrite_all_py_files() -> None: + skip = {ROOT / "scripts" / "migrate_to_lib.py"} + for path in ROOT.rglob("*.py"): + if path in skip or ".venv" in path.parts or "__pycache__" in path.parts: + continue + original = path.read_text(encoding="utf-8") + updated = rewrite_imports_in_text(original) + updated = patch_path_literals(updated) + updated = ensure_paths_import(updated, path) + if updated != original: + path.write_text(updated, encoding="utf-8") + + +def main() -> int: + move_files() + rewrite_all_py_files() + print("Migration complete.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/patch_position_sizing_to_exchanges.py b/scripts/patch_position_sizing_to_exchanges.py index e765e47..8ecf60c 100644 --- a/scripts/patch_position_sizing_to_exchanges.py +++ b/scripts/patch_position_sizing_to_exchanges.py @@ -1,197 +1,197 @@ -#!/usr/bin/env python3 -"""一次性:为 okx/gate/gate_bot 注入与 binance 一致的计仓模式补丁(已 patch 过则跳过)。""" -from __future__ import annotations - -import re -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] - -IMPORT_BLOCK = '''from position_sizing_lib import ( - OPEN_SOURCE_KEY_AUTO, - OPEN_SOURCE_MANUAL, - assert_open_source_allowed, - compute_full_margin_sizing, - full_margin_requires_flat_position, - is_full_margin_mode, - leverage_for_full_margin, - load_position_sizing_mode, - mode_label_zh, -) -from key_monitor_full_margin_lib import ( - monitor_type_disallowed_in_full_margin, - purge_disallowed_key_monitors, -) -''' - -ENV_LINE = ( - "# 计仓模式:risk=以损定仓(默认);full_margin=合约可用×比例全仓杠杆(仅 env 切换,须无仓)\n" - "POSITION_SIZING_MODE = load_position_sizing_mode()\n" -) - -PURGE_FN = ''' - -def _purge_key_monitors_if_full_margin(): - if not is_full_margin_mode(POSITION_SIZING_MODE): - return - conn = get_db() - try: - cancel = globals().get("_cancel_fib_monitor_limit") - if not callable(cancel): - cancel = lambda _row: None - purge_disallowed_key_monitors( - conn, - sizing_mode=POSITION_SIZING_MODE, - select_rows=lambda c: c.execute("SELECT * FROM key_monitors").fetchall(), - cancel_fib_limit=cancel, - delete_monitor=lambda c, kid: c.execute("DELETE FROM key_monitors WHERE id=?", (kid,)), - send_wechat=send_wechat_msg, - ) - conn.commit() - except Exception as e: - print(f"[full_margin] purge key monitors: {e}", flush=True) - finally: - conn.close() - - -''' - -MARKET_OPEN_GUARD = ''' ok_src, src_msg = assert_open_source_allowed(POSITION_SIZING_MODE, OPEN_SOURCE_KEY_AUTO) - if not ok_src: - return False, src_msg, None -''' - -ADD_KEY_GUARD = ''' if is_full_margin_mode(POSITION_SIZING_MODE) and monitor_type_disallowed_in_full_margin(mt): - flash( - "全仓杠杆模式下不可添加箱体/收敛突破或斐波监控;" - "请改用阻力/支撑(仅提醒),或切换 POSITION_SIZING_MODE=risk 并重启(须无持仓)。" - ) - return redirect("/key_monitor") -''' - -TEMPLATE_RULE = '''
- 计仓模式:{{ position_sizing_mode_label }}(仅 .env POSITION_SIZING_MODE,须无仓后重启) - {% if position_sizing_mode == 'full_margin' %} - |全仓:合约可用×{{ full_margin_buffer_ratio }},BTC/ETH {{ btc_leverage }}x、其它 {{ alt_leverage }}x,单仓;张数按交易所精度 - {% else %} - |以损定仓:风险 {{ risk_percent }}% - {% endif %} - |移动保本:下单可勾选关闭;开启时 {{ breakeven_rr_trigger }}R 触发(每 1R 阶梯上移),偏移 {{ breakeven_offset_pct }}% -
''' - -APPS = [ - ("crypto_monitor_okx", 4, "_market_open_for_key_monitor", True), - ("crypto_monitor_gate", 2, "_market_open_for_key_monitor", True), - ("crypto_monitor_gate_bot", 4, None, False), -] - - -def patch_app(app_dir: str, funds_dec: int, market_fn: str | None, has_fib: bool): - path = ROOT / app_dir / "app.py" - text = path.read_text(encoding="utf-8") - if "POSITION_SIZING_MODE" in text: - print(f"SKIP {app_dir}/app.py (already patched)") - return - if "from position_sizing_lib import" not in text: - anchor = "from key_monitor_lib import (" - if anchor not in text: - anchor = "from form_submit_lib import" - text = text.replace( - anchor, - IMPORT_BLOCK + "\n" + anchor, - 1, - ) - else: - text = text.replace(anchor, IMPORT_BLOCK + anchor, 1) - if "POSITION_SIZING_MODE = load_position_sizing_mode()" not in text: - text = text.replace( - "AUTO_TRANSFER_BJ_HOUR = int(os.getenv(\"AUTO_TRANSFER_BJ_HOUR\", \"8\"))\n", - "AUTO_TRANSFER_BJ_HOUR = int(os.getenv(\"AUTO_TRANSFER_BJ_HOUR\", \"8\"))\n" + ENV_LINE, - 1, - ) - if "_purge_key_monitors_if_full_margin" not in text: - text = text.replace("init_db()\n\n\ndef get_db():", "init_db()" + PURGE_FN + "\ndef get_db():", 1) - text = text.replace( - "install_strategy_trend(app,", - "_purge_key_monitors_if_full_margin()\n\ninstall_strategy_trend(app,", - 1, - ) - if market_fn and MARKET_OPEN_GUARD.strip() not in text: - text = text.replace( - f"def {market_fn}(\n", - f"def {market_fn}(\n", - 1, - ) - text = text.replace( - ' """\n 与手动', - MARKET_OPEN_GUARD + ' """\n 与手动', - 1, - ) - # fallback: after docstring closing - if MARKET_OPEN_GUARD.strip() not in text: - pat = rf"(def {market_fn}\([^)]+\):\s*\n\s*\"\"\"[^\"\"]*\"\"\"\s*\n)" - text = re.sub(pat, r"\1" + MARKET_OPEN_GUARD, text, count=1) - if has_fib and ADD_KEY_GUARD.strip() not in text: - text = text.replace( - ' if mt not in allowed_types:', - ADD_KEY_GUARD + ' if mt not in allowed_types:', - 1, - ) if "if mt not in allowed_types:" in text else text.replace( - ' rank, total = _daily_volume_rank(symbol)', - ADD_KEY_GUARD + ' rank, total = _daily_volume_rank(symbol)', - 1, - ) - # render_template risk_percent= add template vars - if "position_sizing_mode=POSITION_SIZING_MODE" not in text: - text = text.replace( - "risk_percent=RISK_PERCENT,\n", - "risk_percent=RISK_PERCENT,\n" - " position_sizing_mode=POSITION_SIZING_MODE,\n" - " position_sizing_mode_label=mode_label_zh(POSITION_SIZING_MODE),\n" - " open_position_button_label=(\n" - ' "开仓(全仓杠杆)" if is_full_margin_mode(POSITION_SIZING_MODE) else "开仓(以损定仓)"\n' - " ),\n", - 1, - ) - path.write_text(text, encoding="utf-8") - print(f"DONE {app_dir}/app.py (partial — verify add_order block manually if needed)") - - -def patch_template(app_dir: str): - tpl = ROOT / app_dir / "templates" / "index.html" - if not tpl.exists(): - return - text = tpl.read_text(encoding="utf-8") - if "position_sizing_mode_label" in text: - print(f"SKIP {tpl}") - return - old = re.search( - r'
\s*以损定仓:风险 \{\{ risk_percent \}\}%.*?
', - text, - re.S, - ) - if old: - text = text[: old.start()] + TEMPLATE_RULE + text[old.end() :] - text = text.replace( - '', - '', - ) - text = text.replace( - '', - '{% if position_sizing_mode != \'full_margin\' %}\n' - ' \n' - ' {% endif %}', - 1, - ) - tpl.write_text(text, encoding="utf-8") - print(f"DONE {tpl}") - - -def main(): - for app_dir, funds, mfn, fib in APPS: - patch_app(app_dir, funds, mfn, fib) - patch_template(app_dir) - - -if __name__ == "__main__": - main() +#!/usr/bin/env python3 +"""一次性:为 okx/gate/gate_bot 注入与 binance 一致的计仓模式补丁(已 patch 过则跳过)。""" +from __future__ import annotations + +import re +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] + +IMPORT_BLOCK = '''from position_sizing_lib import ( + OPEN_SOURCE_KEY_AUTO, + OPEN_SOURCE_MANUAL, + assert_open_source_allowed, + compute_full_margin_sizing, + full_margin_requires_flat_position, + is_full_margin_mode, + leverage_for_full_margin, + load_position_sizing_mode, + mode_label_zh, +) +from lib.key_monitor.key_monitor_full_margin_lib import ( + monitor_type_disallowed_in_full_margin, + purge_disallowed_key_monitors, +) +''' + +ENV_LINE = ( + "# 计仓模式:risk=以损定仓(默认);full_margin=合约可用×比例全仓杠杆(仅 env 切换,须无仓)\n" + "POSITION_SIZING_MODE = load_position_sizing_mode()\n" +) + +PURGE_FN = ''' + +def _purge_key_monitors_if_full_margin(): + if not is_full_margin_mode(POSITION_SIZING_MODE): + return + conn = get_db() + try: + cancel = globals().get("_cancel_fib_monitor_limit") + if not callable(cancel): + cancel = lambda _row: None + purge_disallowed_key_monitors( + conn, + sizing_mode=POSITION_SIZING_MODE, + select_rows=lambda c: c.execute("SELECT * FROM key_monitors").fetchall(), + cancel_fib_limit=cancel, + delete_monitor=lambda c, kid: c.execute("DELETE FROM key_monitors WHERE id=?", (kid,)), + send_wechat=send_wechat_msg, + ) + conn.commit() + except Exception as e: + print(f"[full_margin] purge key monitors: {e}", flush=True) + finally: + conn.close() + + +''' + +MARKET_OPEN_GUARD = ''' ok_src, src_msg = assert_open_source_allowed(POSITION_SIZING_MODE, OPEN_SOURCE_KEY_AUTO) + if not ok_src: + return False, src_msg, None +''' + +ADD_KEY_GUARD = ''' if is_full_margin_mode(POSITION_SIZING_MODE) and monitor_type_disallowed_in_full_margin(mt): + flash( + "全仓杠杆模式下不可添加箱体/收敛突破或斐波监控;" + "请改用阻力/支撑(仅提醒),或切换 POSITION_SIZING_MODE=risk 并重启(须无持仓)。" + ) + return redirect("/key_monitor") +''' + +TEMPLATE_RULE = '''
+ 计仓模式:{{ position_sizing_mode_label }}(仅 .env POSITION_SIZING_MODE,须无仓后重启) + {% if position_sizing_mode == 'full_margin' %} + |全仓:合约可用×{{ full_margin_buffer_ratio }},BTC/ETH {{ btc_leverage }}x、其它 {{ alt_leverage }}x,单仓;张数按交易所精度 + {% else %} + |以损定仓:风险 {{ risk_percent }}% + {% endif %} + |移动保本:下单可勾选关闭;开启时 {{ breakeven_rr_trigger }}R 触发(每 1R 阶梯上移),偏移 {{ breakeven_offset_pct }}% +
''' + +APPS = [ + ("crypto_monitor_okx", 4, "_market_open_for_key_monitor", True), + ("crypto_monitor_gate", 2, "_market_open_for_key_monitor", True), + ("crypto_monitor_gate_bot", 4, None, False), +] + + +def patch_app(app_dir: str, funds_dec: int, market_fn: str | None, has_fib: bool): + path = ROOT / app_dir / "app.py" + text = path.read_text(encoding="utf-8") + if "POSITION_SIZING_MODE" in text: + print(f"SKIP {app_dir}/app.py (already patched)") + return + if "from position_sizing_lib import" not in text: + anchor = "from key_monitor_lib import (" + if anchor not in text: + anchor = "from form_submit_lib import" + text = text.replace( + anchor, + IMPORT_BLOCK + "\n" + anchor, + 1, + ) + else: + text = text.replace(anchor, IMPORT_BLOCK + anchor, 1) + if "POSITION_SIZING_MODE = load_position_sizing_mode()" not in text: + text = text.replace( + "AUTO_TRANSFER_BJ_HOUR = int(os.getenv(\"AUTO_TRANSFER_BJ_HOUR\", \"8\"))\n", + "AUTO_TRANSFER_BJ_HOUR = int(os.getenv(\"AUTO_TRANSFER_BJ_HOUR\", \"8\"))\n" + ENV_LINE, + 1, + ) + if "_purge_key_monitors_if_full_margin" not in text: + text = text.replace("init_db()\n\n\ndef get_db():", "init_db()" + PURGE_FN + "\ndef get_db():", 1) + text = text.replace( + "install_strategy_trend(app,", + "_purge_key_monitors_if_full_margin()\n\ninstall_strategy_trend(app,", + 1, + ) + if market_fn and MARKET_OPEN_GUARD.strip() not in text: + text = text.replace( + f"def {market_fn}(\n", + f"def {market_fn}(\n", + 1, + ) + text = text.replace( + ' """\n 与手动', + MARKET_OPEN_GUARD + ' """\n 与手动', + 1, + ) + # fallback: after docstring closing + if MARKET_OPEN_GUARD.strip() not in text: + pat = rf"(def {market_fn}\([^)]+\):\s*\n\s*\"\"\"[^\"\"]*\"\"\"\s*\n)" + text = re.sub(pat, r"\1" + MARKET_OPEN_GUARD, text, count=1) + if has_fib and ADD_KEY_GUARD.strip() not in text: + text = text.replace( + ' if mt not in allowed_types:', + ADD_KEY_GUARD + ' if mt not in allowed_types:', + 1, + ) if "if mt not in allowed_types:" in text else text.replace( + ' rank, total = _daily_volume_rank(symbol)', + ADD_KEY_GUARD + ' rank, total = _daily_volume_rank(symbol)', + 1, + ) + # render_template risk_percent= add template vars + if "position_sizing_mode=POSITION_SIZING_MODE" not in text: + text = text.replace( + "risk_percent=RISK_PERCENT,\n", + "risk_percent=RISK_PERCENT,\n" + " position_sizing_mode=POSITION_SIZING_MODE,\n" + " position_sizing_mode_label=mode_label_zh(POSITION_SIZING_MODE),\n" + " open_position_button_label=(\n" + ' "开仓(全仓杠杆)" if is_full_margin_mode(POSITION_SIZING_MODE) else "开仓(以损定仓)"\n' + " ),\n", + 1, + ) + path.write_text(text, encoding="utf-8") + print(f"DONE {app_dir}/app.py (partial — verify add_order block manually if needed)") + + +def patch_template(app_dir: str): + tpl = ROOT / app_dir / "templates" / "index.html" + if not tpl.exists(): + return + text = tpl.read_text(encoding="utf-8") + if "position_sizing_mode_label" in text: + print(f"SKIP {tpl}") + return + old = re.search( + r'
\s*以损定仓:风险 \{\{ risk_percent \}\}%.*?
', + text, + re.S, + ) + if old: + text = text[: old.start()] + TEMPLATE_RULE + text[old.end() :] + text = text.replace( + '', + '', + ) + text = text.replace( + '', + '{% if position_sizing_mode != \'full_margin\' %}\n' + ' \n' + ' {% endif %}', + 1, + ) + tpl.write_text(text, encoding="utf-8") + print(f"DONE {tpl}") + + +def main(): + for app_dir, funds, mfn, fib in APPS: + patch_app(app_dir, funds, mfn, fib) + patch_template(app_dir) + + +if __name__ == "__main__": + main() diff --git a/tests/test_account_risk_lib.py b/tests/test_account_risk_lib.py index 9882f7d..f3fd22d 100644 --- a/tests/test_account_risk_lib.py +++ b/tests/test_account_risk_lib.py @@ -1,526 +1,526 @@ -import os -import sqlite3 -import unittest -from datetime import datetime -from unittest import mock -from zoneinfo import ZoneInfo - -from account_risk_lib import ( - CLOSE_SOURCE_USER_HUB, - CLOSE_SOURCE_USER_INSTANCE, - CLOSE_SOURCE_USER_TREND_STOP, - STATUS_DAILY, - STATUS_FREEZE_1H, - STATUS_FREEZE_4H, - STATUS_FREEZE_POSITION, - STATUS_NORMAL, - account_risk_blocks_trading, - apply_position_limit_risk, - compute_account_risk_status, - enrich_risk_status_countdown, - ensure_account_risk_schema, - max_active_positions_from_env, - on_journal_saved, - on_manual_close, - on_user_initiated_close, - parse_mood_issues, -) - -APP_TZ = ZoneInfo("Asia/Shanghai") - - -def _mem_conn(): - conn = sqlite3.connect(":memory:") - conn.row_factory = sqlite3.Row - ensure_account_risk_schema(conn) - return conn - - -def _mem_conn_with_journal(): - conn = _mem_conn() - conn.execute( - """CREATE TABLE IF NOT EXISTS journal_entries ( - close_datetime TEXT, early_exit_trigger TEXT, early_exit_note TEXT - )""" - ) - return conn - - -def _local_ms(dt_naive: datetime) -> int: - return int(dt_naive.replace(tzinfo=APP_TZ).timestamp() * 1000) - - -class AccountRiskLibTests(unittest.TestCase): - def setUp(self): - self.env_patch = mock.patch.dict(os.environ, {}, clear=False) - self.env_patch.start() - os.environ["RISK_CONTROL_ENABLED"] = "1" - os.environ["RISK_COOLING_HOURS_MANUAL"] = "4" - os.environ["RISK_COOLING_HOURS_MANUAL_JOURNAL"] = "1" - os.environ["RISK_MANUAL_CLOSE_DAILY_LIMIT"] = "2" - os.environ["RISK_MOOD_ISSUES_DAILY_FREEZE"] = "1" - os.environ["APP_TIMEZONE"] = "Asia/Shanghai" - - def tearDown(self): - self.env_patch.stop() - - def test_user_instance_sets_4h_cooloff(self): - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 0, 0) - close_ms = _local_ms(now) - on_user_initiated_close( - conn, - source=CLOSE_SOURCE_USER_INSTANCE, - trade_record_id=101, - closed_at_ms=close_ms, - trading_day="2026-06-14", - now=now, - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - self.assertEqual(st["status"], STATUS_FREEZE_4H) - self.assertFalse(st["can_trade"]) - self.assertAlmostEqual(st["freeze_remaining_sec"], 4 * 3600, delta=2) - - def test_invalid_source_ignored(self): - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 0, 0) - on_user_initiated_close( - conn, - source="exchange_tpsl", - trading_day="2026-06-14", - now=now, - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - self.assertEqual(st["status"], STATUS_NORMAL) - - def test_second_user_close_daily_freeze(self): - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 0, 0) - close_ms = _local_ms(now) - on_user_initiated_close( - conn, source=CLOSE_SOURCE_USER_HUB, closed_at_ms=close_ms, trading_day="2026-06-14", now=now - ) - on_user_initiated_close( - conn, source=CLOSE_SOURCE_USER_HUB, closed_at_ms=close_ms + 1000, trading_day="2026-06-14", now=now - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - self.assertEqual(st["status"], STATUS_DAILY) - - def test_hub_close_all_count(self): - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 0, 0) - close_ms = _local_ms(now) - on_user_initiated_close( - conn, - source=CLOSE_SOURCE_USER_HUB, - closed_at_ms=close_ms, - trading_day="2026-06-14", - now=now, - count=2, - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - self.assertEqual(st["manual_close_count"], 2) - self.assertEqual(st["status"], STATUS_DAILY) - - def test_trend_stop_counts_as_manual(self): - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 0, 0) - on_user_initiated_close( - conn, - source=CLOSE_SOURCE_USER_TREND_STOP, - trading_day="2026-06-14", - now=now, - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - self.assertEqual(st["manual_close_count"], 1) - self.assertEqual(st["status"], STATUS_FREEZE_4H) - self.assertAlmostEqual(st["freeze_remaining_sec"], 4 * 3600, delta=2) - - def test_journal_manual_with_note_reduces_to_1h(self): - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 0, 0) - close_ms = _local_ms(now) - on_manual_close(conn, trade_record_id=9, closed_at_ms=close_ms, trading_day="2026-06-14", now=now) - on_journal_saved( - conn, - early_exit_trigger="手动平仓", - early_exit_note="违反计划提前离场", - mood_issues_raw="", - trading_day="2026-06-14", - now=now, - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - self.assertEqual(st["status"], STATUS_FREEZE_1H) - self.assertAlmostEqual(st["freeze_remaining_sec"], 3600, delta=2) - - def test_journal_hub_close_without_pending_reduces_to_1h(self): - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 0, 0) - close_ms = _local_ms(now) - on_user_initiated_close( - conn, - source=CLOSE_SOURCE_USER_HUB, - closed_at_ms=close_ms, - trading_day="2026-06-14", - now=now, - ) - on_journal_saved( - conn, - early_exit_trigger="手动平仓", - early_exit_note="中控全平后复盘说明", - mood_issues_raw="", - trading_day="2026-06-14", - now=now, - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - self.assertEqual(st["status"], STATUS_FREEZE_1H) - - def test_journal_reduces_when_manual_count_cleared_but_cooloff_active(self): - conn = _mem_conn() - now = datetime(2026, 6, 15, 10, 0, 0) - now_ms = _local_ms(now) - close_ms = now_ms - 3600 * 1000 - until_ms = close_ms + 4 * 3600 * 1000 - conn.execute( - """UPDATE account_risk_state SET - trading_day='2026-06-15', - manual_close_count=0, - cooloff_until_ms=?, - cooloff_hours=4, - last_close_at_ms=?, - daily_frozen=0 - WHERE id=1""", - (until_ms, close_ms), - ) - on_journal_saved( - conn, - early_exit_trigger="手动平仓", - early_exit_note="切日后补复盘", - mood_issues_raw="", - trading_day="2026-06-15", - now=now, - ) - st = compute_account_risk_status(conn, trading_day="2026-06-15", now=now) - self.assertEqual(st["status"], STATUS_FREEZE_1H) - - def test_journal_late_save_still_gets_1h_from_now(self): - conn = _mem_conn() - close_at = datetime(2026, 6, 14, 12, 0, 0) - close_ms = _local_ms(close_at) - on_user_initiated_close( - conn, - source=CLOSE_SOURCE_USER_INSTANCE, - closed_at_ms=close_ms, - trading_day="2026-06-14", - now=close_at, - ) - journal_at = datetime(2026, 6, 14, 14, 0, 0) - on_journal_saved( - conn, - early_exit_trigger="手动平仓", - early_exit_note="补写复盘说明", - mood_issues_raw="", - trading_day="2026-06-14", - now=journal_at, - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=journal_at) - self.assertEqual(st["status"], STATUS_FREEZE_1H) - self.assertEqual(st["cooloff_until_ms"], _local_ms(journal_at) + 3600 * 1000) - - def test_stale_4h_until_with_1h_hours_uses_shorter_end(self): - """库内 cooloff_hours=1 但 cooloff_until_ms 仍为旧 4h 时,应按 last_close+1h 倒计时。""" - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 6, 0) - now_ms = _local_ms(now) - close_ms = now_ms - 6 * 60 * 1000 - stale_until_4h = close_ms + 4 * 3600 * 1000 - conn.execute( - """UPDATE account_risk_state SET - trading_day='2026-06-14', - manual_close_count=1, - cooloff_until_ms=?, - cooloff_hours=1, - last_close_at_ms=?, - daily_frozen=0 - WHERE id=1""", - (stale_until_4h, close_ms), - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - self.assertEqual(st["status"], STATUS_FREEZE_1H) - self.assertAlmostEqual(st["freeze_remaining_sec"], 54 * 60, delta=3) - - def test_stale_4h_ignored_after_1h_journal_expired(self): - """复盘已降为 1h 且窗口结束后,不应再读库内旧 4h until。""" - conn = _mem_conn() - close_at = datetime(2026, 6, 18, 17, 56, 0) - now = datetime(2026, 6, 18, 21, 50, 0) - close_ms = _local_ms(close_at) - stale_4h_until = close_ms + 4 * 3600 * 1000 - conn.execute( - """UPDATE account_risk_state SET - trading_day='2026-06-18', - manual_close_count=1, - cooloff_until_ms=?, - cooloff_hours=1, - last_close_at_ms=?, - daily_frozen=0 - WHERE id=1""", - (stale_4h_until, close_ms), - ) - st = compute_account_risk_status(conn, trading_day="2026-06-18", now=now) - self.assertEqual(st["status"], STATUS_NORMAL) - self.assertTrue(st["can_trade"]) - row = conn.execute( - "SELECT cooloff_until_ms, cooloff_hours, last_close_at_ms FROM account_risk_state WHERE id=1" - ).fetchone() - self.assertIsNone(row["cooloff_until_ms"]) - self.assertIsNone(row["last_close_at_ms"]) - - def test_corrupted_anchor_cleared_when_journaled_manual_expired(self): - """上一版误把 last_close 写成近期时刻时,已复盘且 1h 已过的仍应显示正常。""" - conn = _mem_conn_with_journal() - now = datetime(2026, 6, 18, 22, 30, 0) - now_ms = _local_ms(now) - bad_last = now_ms - 60 * 1000 - conn.execute( - """UPDATE account_risk_state SET - trading_day='2026-06-18', - manual_close_count=1, - cooloff_until_ms=?, - cooloff_hours=1, - last_close_at_ms=?, - pending_journal_trade_id=NULL, - daily_frozen=0 - WHERE id=1""", - (bad_last + 3600 * 1000, bad_last), - ) - conn.execute( - "INSERT INTO journal_entries (close_datetime, early_exit_trigger, early_exit_note) VALUES (?,?,?)", - ("2026-06-18 17:56:00", "手动平仓", "按计划离场"), - ) - st = compute_account_risk_status(conn, trading_day="2026-06-18", now=now) - self.assertEqual(st["status"], STATUS_NORMAL) - self.assertTrue(st["can_trade"]) - - def test_future_last_close_does_not_restart_cooloff(self): - """脏数据 last_close 在未来时,不应重启 1h/4h 冻结。""" - conn = _mem_conn() - now = datetime(2026, 6, 18, 22, 30, 0) - now_ms = _local_ms(now) - future_close = now_ms + 49 * 60 * 1000 - conn.execute( - """UPDATE account_risk_state SET - trading_day='2026-06-18', - manual_close_count=1, - cooloff_until_ms=?, - cooloff_hours=1, - last_close_at_ms=?, - daily_frozen=0 - WHERE id=1""", - (future_close + 3600 * 1000, future_close), - ) - st = compute_account_risk_status(conn, trading_day="2026-06-18", now=now) - self.assertEqual(st["status"], STATUS_NORMAL) - self.assertTrue(st["can_trade"]) - - def test_active_4h_countdown_matches_tier(self): - conn = _mem_conn() - close_at = datetime(2026, 6, 18, 21, 46, 0) - now = datetime(2026, 6, 18, 21, 52, 0) - close_ms = _local_ms(close_at) - on_user_initiated_close( - conn, - source=CLOSE_SOURCE_USER_INSTANCE, - closed_at_ms=close_ms, - trading_day="2026-06-18", - now=close_at, - ) - st = compute_account_risk_status(conn, trading_day="2026-06-18", now=now) - self.assertEqual(st["status"], STATUS_FREEZE_4H) - self.assertAlmostEqual(st["freeze_remaining_sec"], 3 * 3600 + 54 * 60, delta=5) - - def test_trading_day_reset_clears_expired_stale_cooloff(self): - conn = _mem_conn() - close_at = datetime(2026, 6, 18, 17, 56, 0) - close_ms = _local_ms(close_at) - stale_4h_until = close_ms + 4 * 3600 * 1000 - conn.execute( - """UPDATE account_risk_state SET - trading_day='2026-06-18', - manual_close_count=1, - cooloff_until_ms=?, - cooloff_hours=1, - last_close_at_ms=?, - daily_frozen=0 - WHERE id=1""", - (stale_4h_until, close_ms), - ) - next_day = datetime(2026, 6, 19, 9, 0, 0) - st = compute_account_risk_status(conn, trading_day="2026-06-19", now=next_day) - self.assertEqual(st["status"], STATUS_NORMAL) - row = conn.execute("SELECT cooloff_until_ms FROM account_risk_state WHERE id=1").fetchone() - self.assertIsNone(row["cooloff_until_ms"]) - - def test_remaining_never_exceeds_configured_hours(self): - conn = _mem_conn() - now = datetime(2026, 6, 18, 22, 0, 0) - now_ms = _local_ms(now) - future_close = now_ms + 49 * 60 * 1000 - conn.execute( - """UPDATE account_risk_state SET - trading_day='2026-06-18', - manual_close_count=1, - cooloff_until_ms=?, - cooloff_hours=4, - last_close_at_ms=?, - daily_frozen=0 - WHERE id=1""", - (future_close + 4 * 3600 * 1000, future_close), - ) - st = compute_account_risk_status(conn, trading_day="2026-06-18", now=now) - self.assertEqual(st["status"], STATUS_NORMAL) - self.assertTrue(st["can_trade"]) - - def test_legacy_naive_utc_ms_countdown_normalized(self): - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 0, 0) - now_ms = _local_ms(now) - offset_ms = 8 * 3600 * 1000 - legacy_close = now_ms + offset_ms - legacy_until = legacy_close + 4 * 3600 * 1000 - conn.execute( - """UPDATE account_risk_state SET - trading_day='2026-06-14', - manual_close_count=1, - cooloff_until_ms=?, - cooloff_hours=4, - last_close_at_ms=?, - daily_frozen=0 - WHERE id=1""", - (legacy_until, legacy_close), - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - st = enrich_risk_status_countdown(st, now=now, daily_reset_hour=8) - self.assertEqual(st["status"], STATUS_FREEZE_4H) - self.assertAlmostEqual(st["freeze_remaining_sec"], 4 * 3600, delta=2) - - def test_journal_mood_issues_daily_freeze(self): - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 0, 0) - on_journal_saved( - conn, - early_exit_trigger="止损", - early_exit_note="", - mood_issues_raw=["报复开仓"], - trading_day="2026-06-14", - now=now, - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - self.assertEqual(st["status"], STATUS_DAILY) - - def test_cooloff_expired_returns_normal(self): - conn = _mem_conn() - start = datetime(2026, 6, 14, 8, 0, 0) - close_ms = _local_ms(start) - on_user_initiated_close( - conn, source=CLOSE_SOURCE_USER_INSTANCE, closed_at_ms=close_ms, trading_day="2026-06-14", now=start - ) - later = datetime(2026, 6, 14, 13, 0, 0) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=later) - self.assertEqual(st["status"], STATUS_NORMAL) - row = conn.execute("SELECT cooloff_until_ms FROM account_risk_state WHERE id=1").fetchone() - self.assertIsNone(row["cooloff_until_ms"]) - - def test_trading_day_reset_clears_daily_frozen(self): - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 0, 0) - on_journal_saved( - conn, - early_exit_trigger="止损", - early_exit_note="", - mood_issues_raw="扛单", - trading_day="2026-06-14", - now=now, - ) - next_day = datetime(2026, 6, 15, 8, 0, 0) - st = compute_account_risk_status(conn, trading_day="2026-06-15", now=next_day) - self.assertEqual(st["status"], STATUS_NORMAL) - - def test_parse_mood_issues_filters_unknown(self): - self.assertEqual(parse_mood_issues("怕踏空,未知标签,扛单"), ["怕踏空", "扛单"]) - - def test_enrich_countdown_for_daily_and_cooloff(self): - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 0, 0) - close_ms = _local_ms(now) - on_user_initiated_close( - conn, - source=CLOSE_SOURCE_USER_INSTANCE, - closed_at_ms=close_ms, - trading_day="2026-06-14", - now=now, - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - st = enrich_risk_status_countdown(st, now=now, daily_reset_hour=8) - self.assertGreater(st["freeze_remaining_sec"], 0) - self.assertEqual(st["freeze_until_ms"], st["cooloff_until_ms"]) - - on_journal_saved( - conn, - early_exit_trigger="止损", - early_exit_note="", - mood_issues_raw="扛单", - trading_day="2026-06-14", - now=now, - ) - st2 = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - st2 = enrich_risk_status_countdown(st2, now=now, daily_reset_hour=8) - self.assertTrue(st2["daily_frozen"]) - self.assertGreater(st2["freeze_remaining_sec"], 0) - self.assertIsNotNone(st2["freeze_until_ms"]) - - def test_disabled_risk_control(self): - os.environ["RISK_CONTROL_ENABLED"] = "0" - conn = _mem_conn() - now = datetime(2026, 6, 14, 12, 0, 0) - on_user_initiated_close( - conn, source=CLOSE_SOURCE_USER_INSTANCE, trading_day="2026-06-14", now=now - ) - st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) - self.assertFalse(st["enabled"]) - self.assertTrue(st["can_trade"]) - ok, _ = account_risk_blocks_trading(conn, trading_day="2026-06-14", now=now) - self.assertTrue(ok) - - def test_position_limit_freeze_from_env(self): - os.environ["MAX_ACTIVE_POSITIONS"] = "2" - st = apply_position_limit_risk({"status": STATUS_NORMAL, "can_trade": True}, 2) - self.assertEqual(st["status"], STATUS_FREEZE_POSITION) - self.assertEqual(st["status_label"], "仓位上限冻结") - self.assertFalse(st["can_trade"]) - self.assertIn("2/2", st["reason"]) - self.assertIn("顺势加仓", st["reason"]) - self.assertTrue(st.get("can_roll")) - self.assertEqual(st["max_active_positions"], 2) - - def test_position_limit_normal_when_under_cap(self): - st = apply_position_limit_risk({"status": STATUS_NORMAL, "can_trade": True}, 0, max_active_positions=1) - self.assertEqual(st["status"], STATUS_NORMAL) - self.assertTrue(st["can_trade"]) - - def test_time_freeze_takes_priority_over_position_limit(self): - st = apply_position_limit_risk( - {"status": STATUS_FREEZE_4H, "status_label": "4h冻结", "can_trade": False}, - 5, - max_active_positions=1, - ) - self.assertEqual(st["status"], STATUS_FREEZE_4H) - self.assertEqual(st["active_count"], 5) - - def test_max_active_positions_from_env(self): - os.environ["MAX_ACTIVE_POSITIONS"] = "3" - self.assertEqual(max_active_positions_from_env(), 3) - - -if __name__ == "__main__": - unittest.main() +import os +import sqlite3 +import unittest +from datetime import datetime +from unittest import mock +from zoneinfo import ZoneInfo + +from lib.trade.account_risk_lib import ( + CLOSE_SOURCE_USER_HUB, + CLOSE_SOURCE_USER_INSTANCE, + CLOSE_SOURCE_USER_TREND_STOP, + STATUS_DAILY, + STATUS_FREEZE_1H, + STATUS_FREEZE_4H, + STATUS_FREEZE_POSITION, + STATUS_NORMAL, + account_risk_blocks_trading, + apply_position_limit_risk, + compute_account_risk_status, + enrich_risk_status_countdown, + ensure_account_risk_schema, + max_active_positions_from_env, + on_journal_saved, + on_manual_close, + on_user_initiated_close, + parse_mood_issues, +) + +APP_TZ = ZoneInfo("Asia/Shanghai") + + +def _mem_conn(): + conn = sqlite3.connect(":memory:") + conn.row_factory = sqlite3.Row + ensure_account_risk_schema(conn) + return conn + + +def _mem_conn_with_journal(): + conn = _mem_conn() + conn.execute( + """CREATE TABLE IF NOT EXISTS journal_entries ( + close_datetime TEXT, early_exit_trigger TEXT, early_exit_note TEXT + )""" + ) + return conn + + +def _local_ms(dt_naive: datetime) -> int: + return int(dt_naive.replace(tzinfo=APP_TZ).timestamp() * 1000) + + +class AccountRiskLibTests(unittest.TestCase): + def setUp(self): + self.env_patch = mock.patch.dict(os.environ, {}, clear=False) + self.env_patch.start() + os.environ["RISK_CONTROL_ENABLED"] = "1" + os.environ["RISK_COOLING_HOURS_MANUAL"] = "4" + os.environ["RISK_COOLING_HOURS_MANUAL_JOURNAL"] = "1" + os.environ["RISK_MANUAL_CLOSE_DAILY_LIMIT"] = "2" + os.environ["RISK_MOOD_ISSUES_DAILY_FREEZE"] = "1" + os.environ["APP_TIMEZONE"] = "Asia/Shanghai" + + def tearDown(self): + self.env_patch.stop() + + def test_user_instance_sets_4h_cooloff(self): + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 0, 0) + close_ms = _local_ms(now) + on_user_initiated_close( + conn, + source=CLOSE_SOURCE_USER_INSTANCE, + trade_record_id=101, + closed_at_ms=close_ms, + trading_day="2026-06-14", + now=now, + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + self.assertEqual(st["status"], STATUS_FREEZE_4H) + self.assertFalse(st["can_trade"]) + self.assertAlmostEqual(st["freeze_remaining_sec"], 4 * 3600, delta=2) + + def test_invalid_source_ignored(self): + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 0, 0) + on_user_initiated_close( + conn, + source="exchange_tpsl", + trading_day="2026-06-14", + now=now, + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + self.assertEqual(st["status"], STATUS_NORMAL) + + def test_second_user_close_daily_freeze(self): + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 0, 0) + close_ms = _local_ms(now) + on_user_initiated_close( + conn, source=CLOSE_SOURCE_USER_HUB, closed_at_ms=close_ms, trading_day="2026-06-14", now=now + ) + on_user_initiated_close( + conn, source=CLOSE_SOURCE_USER_HUB, closed_at_ms=close_ms + 1000, trading_day="2026-06-14", now=now + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + self.assertEqual(st["status"], STATUS_DAILY) + + def test_hub_close_all_count(self): + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 0, 0) + close_ms = _local_ms(now) + on_user_initiated_close( + conn, + source=CLOSE_SOURCE_USER_HUB, + closed_at_ms=close_ms, + trading_day="2026-06-14", + now=now, + count=2, + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + self.assertEqual(st["manual_close_count"], 2) + self.assertEqual(st["status"], STATUS_DAILY) + + def test_trend_stop_counts_as_manual(self): + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 0, 0) + on_user_initiated_close( + conn, + source=CLOSE_SOURCE_USER_TREND_STOP, + trading_day="2026-06-14", + now=now, + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + self.assertEqual(st["manual_close_count"], 1) + self.assertEqual(st["status"], STATUS_FREEZE_4H) + self.assertAlmostEqual(st["freeze_remaining_sec"], 4 * 3600, delta=2) + + def test_journal_manual_with_note_reduces_to_1h(self): + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 0, 0) + close_ms = _local_ms(now) + on_manual_close(conn, trade_record_id=9, closed_at_ms=close_ms, trading_day="2026-06-14", now=now) + on_journal_saved( + conn, + early_exit_trigger="手动平仓", + early_exit_note="违反计划提前离场", + mood_issues_raw="", + trading_day="2026-06-14", + now=now, + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + self.assertEqual(st["status"], STATUS_FREEZE_1H) + self.assertAlmostEqual(st["freeze_remaining_sec"], 3600, delta=2) + + def test_journal_hub_close_without_pending_reduces_to_1h(self): + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 0, 0) + close_ms = _local_ms(now) + on_user_initiated_close( + conn, + source=CLOSE_SOURCE_USER_HUB, + closed_at_ms=close_ms, + trading_day="2026-06-14", + now=now, + ) + on_journal_saved( + conn, + early_exit_trigger="手动平仓", + early_exit_note="中控全平后复盘说明", + mood_issues_raw="", + trading_day="2026-06-14", + now=now, + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + self.assertEqual(st["status"], STATUS_FREEZE_1H) + + def test_journal_reduces_when_manual_count_cleared_but_cooloff_active(self): + conn = _mem_conn() + now = datetime(2026, 6, 15, 10, 0, 0) + now_ms = _local_ms(now) + close_ms = now_ms - 3600 * 1000 + until_ms = close_ms + 4 * 3600 * 1000 + conn.execute( + """UPDATE account_risk_state SET + trading_day='2026-06-15', + manual_close_count=0, + cooloff_until_ms=?, + cooloff_hours=4, + last_close_at_ms=?, + daily_frozen=0 + WHERE id=1""", + (until_ms, close_ms), + ) + on_journal_saved( + conn, + early_exit_trigger="手动平仓", + early_exit_note="切日后补复盘", + mood_issues_raw="", + trading_day="2026-06-15", + now=now, + ) + st = compute_account_risk_status(conn, trading_day="2026-06-15", now=now) + self.assertEqual(st["status"], STATUS_FREEZE_1H) + + def test_journal_late_save_still_gets_1h_from_now(self): + conn = _mem_conn() + close_at = datetime(2026, 6, 14, 12, 0, 0) + close_ms = _local_ms(close_at) + on_user_initiated_close( + conn, + source=CLOSE_SOURCE_USER_INSTANCE, + closed_at_ms=close_ms, + trading_day="2026-06-14", + now=close_at, + ) + journal_at = datetime(2026, 6, 14, 14, 0, 0) + on_journal_saved( + conn, + early_exit_trigger="手动平仓", + early_exit_note="补写复盘说明", + mood_issues_raw="", + trading_day="2026-06-14", + now=journal_at, + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=journal_at) + self.assertEqual(st["status"], STATUS_FREEZE_1H) + self.assertEqual(st["cooloff_until_ms"], _local_ms(journal_at) + 3600 * 1000) + + def test_stale_4h_until_with_1h_hours_uses_shorter_end(self): + """库内 cooloff_hours=1 但 cooloff_until_ms 仍为旧 4h 时,应按 last_close+1h 倒计时。""" + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 6, 0) + now_ms = _local_ms(now) + close_ms = now_ms - 6 * 60 * 1000 + stale_until_4h = close_ms + 4 * 3600 * 1000 + conn.execute( + """UPDATE account_risk_state SET + trading_day='2026-06-14', + manual_close_count=1, + cooloff_until_ms=?, + cooloff_hours=1, + last_close_at_ms=?, + daily_frozen=0 + WHERE id=1""", + (stale_until_4h, close_ms), + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + self.assertEqual(st["status"], STATUS_FREEZE_1H) + self.assertAlmostEqual(st["freeze_remaining_sec"], 54 * 60, delta=3) + + def test_stale_4h_ignored_after_1h_journal_expired(self): + """复盘已降为 1h 且窗口结束后,不应再读库内旧 4h until。""" + conn = _mem_conn() + close_at = datetime(2026, 6, 18, 17, 56, 0) + now = datetime(2026, 6, 18, 21, 50, 0) + close_ms = _local_ms(close_at) + stale_4h_until = close_ms + 4 * 3600 * 1000 + conn.execute( + """UPDATE account_risk_state SET + trading_day='2026-06-18', + manual_close_count=1, + cooloff_until_ms=?, + cooloff_hours=1, + last_close_at_ms=?, + daily_frozen=0 + WHERE id=1""", + (stale_4h_until, close_ms), + ) + st = compute_account_risk_status(conn, trading_day="2026-06-18", now=now) + self.assertEqual(st["status"], STATUS_NORMAL) + self.assertTrue(st["can_trade"]) + row = conn.execute( + "SELECT cooloff_until_ms, cooloff_hours, last_close_at_ms FROM account_risk_state WHERE id=1" + ).fetchone() + self.assertIsNone(row["cooloff_until_ms"]) + self.assertIsNone(row["last_close_at_ms"]) + + def test_corrupted_anchor_cleared_when_journaled_manual_expired(self): + """上一版误把 last_close 写成近期时刻时,已复盘且 1h 已过的仍应显示正常。""" + conn = _mem_conn_with_journal() + now = datetime(2026, 6, 18, 22, 30, 0) + now_ms = _local_ms(now) + bad_last = now_ms - 60 * 1000 + conn.execute( + """UPDATE account_risk_state SET + trading_day='2026-06-18', + manual_close_count=1, + cooloff_until_ms=?, + cooloff_hours=1, + last_close_at_ms=?, + pending_journal_trade_id=NULL, + daily_frozen=0 + WHERE id=1""", + (bad_last + 3600 * 1000, bad_last), + ) + conn.execute( + "INSERT INTO journal_entries (close_datetime, early_exit_trigger, early_exit_note) VALUES (?,?,?)", + ("2026-06-18 17:56:00", "手动平仓", "按计划离场"), + ) + st = compute_account_risk_status(conn, trading_day="2026-06-18", now=now) + self.assertEqual(st["status"], STATUS_NORMAL) + self.assertTrue(st["can_trade"]) + + def test_future_last_close_does_not_restart_cooloff(self): + """脏数据 last_close 在未来时,不应重启 1h/4h 冻结。""" + conn = _mem_conn() + now = datetime(2026, 6, 18, 22, 30, 0) + now_ms = _local_ms(now) + future_close = now_ms + 49 * 60 * 1000 + conn.execute( + """UPDATE account_risk_state SET + trading_day='2026-06-18', + manual_close_count=1, + cooloff_until_ms=?, + cooloff_hours=1, + last_close_at_ms=?, + daily_frozen=0 + WHERE id=1""", + (future_close + 3600 * 1000, future_close), + ) + st = compute_account_risk_status(conn, trading_day="2026-06-18", now=now) + self.assertEqual(st["status"], STATUS_NORMAL) + self.assertTrue(st["can_trade"]) + + def test_active_4h_countdown_matches_tier(self): + conn = _mem_conn() + close_at = datetime(2026, 6, 18, 21, 46, 0) + now = datetime(2026, 6, 18, 21, 52, 0) + close_ms = _local_ms(close_at) + on_user_initiated_close( + conn, + source=CLOSE_SOURCE_USER_INSTANCE, + closed_at_ms=close_ms, + trading_day="2026-06-18", + now=close_at, + ) + st = compute_account_risk_status(conn, trading_day="2026-06-18", now=now) + self.assertEqual(st["status"], STATUS_FREEZE_4H) + self.assertAlmostEqual(st["freeze_remaining_sec"], 3 * 3600 + 54 * 60, delta=5) + + def test_trading_day_reset_clears_expired_stale_cooloff(self): + conn = _mem_conn() + close_at = datetime(2026, 6, 18, 17, 56, 0) + close_ms = _local_ms(close_at) + stale_4h_until = close_ms + 4 * 3600 * 1000 + conn.execute( + """UPDATE account_risk_state SET + trading_day='2026-06-18', + manual_close_count=1, + cooloff_until_ms=?, + cooloff_hours=1, + last_close_at_ms=?, + daily_frozen=0 + WHERE id=1""", + (stale_4h_until, close_ms), + ) + next_day = datetime(2026, 6, 19, 9, 0, 0) + st = compute_account_risk_status(conn, trading_day="2026-06-19", now=next_day) + self.assertEqual(st["status"], STATUS_NORMAL) + row = conn.execute("SELECT cooloff_until_ms FROM account_risk_state WHERE id=1").fetchone() + self.assertIsNone(row["cooloff_until_ms"]) + + def test_remaining_never_exceeds_configured_hours(self): + conn = _mem_conn() + now = datetime(2026, 6, 18, 22, 0, 0) + now_ms = _local_ms(now) + future_close = now_ms + 49 * 60 * 1000 + conn.execute( + """UPDATE account_risk_state SET + trading_day='2026-06-18', + manual_close_count=1, + cooloff_until_ms=?, + cooloff_hours=4, + last_close_at_ms=?, + daily_frozen=0 + WHERE id=1""", + (future_close + 4 * 3600 * 1000, future_close), + ) + st = compute_account_risk_status(conn, trading_day="2026-06-18", now=now) + self.assertEqual(st["status"], STATUS_NORMAL) + self.assertTrue(st["can_trade"]) + + def test_legacy_naive_utc_ms_countdown_normalized(self): + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 0, 0) + now_ms = _local_ms(now) + offset_ms = 8 * 3600 * 1000 + legacy_close = now_ms + offset_ms + legacy_until = legacy_close + 4 * 3600 * 1000 + conn.execute( + """UPDATE account_risk_state SET + trading_day='2026-06-14', + manual_close_count=1, + cooloff_until_ms=?, + cooloff_hours=4, + last_close_at_ms=?, + daily_frozen=0 + WHERE id=1""", + (legacy_until, legacy_close), + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + st = enrich_risk_status_countdown(st, now=now, daily_reset_hour=8) + self.assertEqual(st["status"], STATUS_FREEZE_4H) + self.assertAlmostEqual(st["freeze_remaining_sec"], 4 * 3600, delta=2) + + def test_journal_mood_issues_daily_freeze(self): + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 0, 0) + on_journal_saved( + conn, + early_exit_trigger="止损", + early_exit_note="", + mood_issues_raw=["报复开仓"], + trading_day="2026-06-14", + now=now, + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + self.assertEqual(st["status"], STATUS_DAILY) + + def test_cooloff_expired_returns_normal(self): + conn = _mem_conn() + start = datetime(2026, 6, 14, 8, 0, 0) + close_ms = _local_ms(start) + on_user_initiated_close( + conn, source=CLOSE_SOURCE_USER_INSTANCE, closed_at_ms=close_ms, trading_day="2026-06-14", now=start + ) + later = datetime(2026, 6, 14, 13, 0, 0) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=later) + self.assertEqual(st["status"], STATUS_NORMAL) + row = conn.execute("SELECT cooloff_until_ms FROM account_risk_state WHERE id=1").fetchone() + self.assertIsNone(row["cooloff_until_ms"]) + + def test_trading_day_reset_clears_daily_frozen(self): + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 0, 0) + on_journal_saved( + conn, + early_exit_trigger="止损", + early_exit_note="", + mood_issues_raw="扛单", + trading_day="2026-06-14", + now=now, + ) + next_day = datetime(2026, 6, 15, 8, 0, 0) + st = compute_account_risk_status(conn, trading_day="2026-06-15", now=next_day) + self.assertEqual(st["status"], STATUS_NORMAL) + + def test_parse_mood_issues_filters_unknown(self): + self.assertEqual(parse_mood_issues("怕踏空,未知标签,扛单"), ["怕踏空", "扛单"]) + + def test_enrich_countdown_for_daily_and_cooloff(self): + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 0, 0) + close_ms = _local_ms(now) + on_user_initiated_close( + conn, + source=CLOSE_SOURCE_USER_INSTANCE, + closed_at_ms=close_ms, + trading_day="2026-06-14", + now=now, + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + st = enrich_risk_status_countdown(st, now=now, daily_reset_hour=8) + self.assertGreater(st["freeze_remaining_sec"], 0) + self.assertEqual(st["freeze_until_ms"], st["cooloff_until_ms"]) + + on_journal_saved( + conn, + early_exit_trigger="止损", + early_exit_note="", + mood_issues_raw="扛单", + trading_day="2026-06-14", + now=now, + ) + st2 = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + st2 = enrich_risk_status_countdown(st2, now=now, daily_reset_hour=8) + self.assertTrue(st2["daily_frozen"]) + self.assertGreater(st2["freeze_remaining_sec"], 0) + self.assertIsNotNone(st2["freeze_until_ms"]) + + def test_disabled_risk_control(self): + os.environ["RISK_CONTROL_ENABLED"] = "0" + conn = _mem_conn() + now = datetime(2026, 6, 14, 12, 0, 0) + on_user_initiated_close( + conn, source=CLOSE_SOURCE_USER_INSTANCE, trading_day="2026-06-14", now=now + ) + st = compute_account_risk_status(conn, trading_day="2026-06-14", now=now) + self.assertFalse(st["enabled"]) + self.assertTrue(st["can_trade"]) + ok, _ = account_risk_blocks_trading(conn, trading_day="2026-06-14", now=now) + self.assertTrue(ok) + + def test_position_limit_freeze_from_env(self): + os.environ["MAX_ACTIVE_POSITIONS"] = "2" + st = apply_position_limit_risk({"status": STATUS_NORMAL, "can_trade": True}, 2) + self.assertEqual(st["status"], STATUS_FREEZE_POSITION) + self.assertEqual(st["status_label"], "仓位上限冻结") + self.assertFalse(st["can_trade"]) + self.assertIn("2/2", st["reason"]) + self.assertIn("顺势加仓", st["reason"]) + self.assertTrue(st.get("can_roll")) + self.assertEqual(st["max_active_positions"], 2) + + def test_position_limit_normal_when_under_cap(self): + st = apply_position_limit_risk({"status": STATUS_NORMAL, "can_trade": True}, 0, max_active_positions=1) + self.assertEqual(st["status"], STATUS_NORMAL) + self.assertTrue(st["can_trade"]) + + def test_time_freeze_takes_priority_over_position_limit(self): + st = apply_position_limit_risk( + {"status": STATUS_FREEZE_4H, "status_label": "4h冻结", "can_trade": False}, + 5, + max_active_positions=1, + ) + self.assertEqual(st["status"], STATUS_FREEZE_4H) + self.assertEqual(st["active_count"], 5) + + def test_max_active_positions_from_env(self): + os.environ["MAX_ACTIVE_POSITIONS"] = "3" + self.assertEqual(max_active_positions_from_env(), 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ai_review_lib.py b/tests/test_ai_review_lib.py index beb7690..1292ce6 100644 --- a/tests/test_ai_review_lib.py +++ b/tests/test_ai_review_lib.py @@ -1,63 +1,63 @@ -"""AI 复盘 journal 文本格式化(四所共用)。""" -from __future__ import annotations - -import sqlite3 -import sys -import unittest -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(ROOT)) - -from ai_review_lib import journal_row_lines_for_ai # noqa: E402 - - -class TestAiReviewLib(unittest.TestCase): - def test_journal_row_includes_expect_and_actual_rr(self): - text = journal_row_lines_for_ai( - 1, - { - "coin": "HYPE", - "tf": "5m", - "pnl": "10.73", - "real_rr": "2.1354", - "expect_rr": "-", - "entry_reason": "趋势回调", - "exit_reason": "移动止盈", - "hold_duration": "1天 3小时", - "mood_issues": "", - "post_breakeven_stare": "否", - "new_trade_while_occupied": "否", - "note": "测试备注", - }, - ) - self.assertIn("实际RR:2.1354", text) - self.assertIn("预期RR:-", text) - self.assertIn("开仓逻辑:趋势回调", text) - self.assertIn("备注:测试备注", text) - self.assertNotIn("开仓类型", text) - - def test_journal_row_accepts_sqlite_row(self): - conn = sqlite3.connect(":memory:") - conn.row_factory = sqlite3.Row - conn.execute( - """CREATE TABLE journal_entries ( - coin TEXT, tf TEXT, pnl TEXT, real_rr TEXT, expect_rr TEXT, - entry_reason TEXT, exit_reason TEXT, hold_duration TEXT, - mood_issues TEXT, mood_score INTEGER, note TEXT - )""" - ) - conn.execute( - """INSERT INTO journal_entries VALUES (?,?,?,?,?,?,?,?,?,?,?)""", - ("BTC", "15m", "5", "1.2", "2.0", "突破", "止盈", "2小时", "", None, ""), - ) - row = conn.execute("SELECT * FROM journal_entries").fetchone() - conn.close() - text = journal_row_lines_for_ai(1, row) - self.assertIn("BTC 15m", text) - self.assertIn("实际RR:1.2", text) - self.assertIn("开仓逻辑:突破", text) - - -if __name__ == "__main__": - unittest.main() +"""AI 复盘 journal 文本格式化(四所共用)。""" +from __future__ import annotations + +import sqlite3 +import sys +import unittest +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + +from lib.ai.ai_review_lib import journal_row_lines_for_ai # noqa: E402 + + +class TestAiReviewLib(unittest.TestCase): + def test_journal_row_includes_expect_and_actual_rr(self): + text = journal_row_lines_for_ai( + 1, + { + "coin": "HYPE", + "tf": "5m", + "pnl": "10.73", + "real_rr": "2.1354", + "expect_rr": "-", + "entry_reason": "趋势回调", + "exit_reason": "移动止盈", + "hold_duration": "1天 3小时", + "mood_issues": "", + "post_breakeven_stare": "否", + "new_trade_while_occupied": "否", + "note": "测试备注", + }, + ) + self.assertIn("实际RR:2.1354", text) + self.assertIn("预期RR:-", text) + self.assertIn("开仓逻辑:趋势回调", text) + self.assertIn("备注:测试备注", text) + self.assertNotIn("开仓类型", text) + + def test_journal_row_accepts_sqlite_row(self): + conn = sqlite3.connect(":memory:") + conn.row_factory = sqlite3.Row + conn.execute( + """CREATE TABLE journal_entries ( + coin TEXT, tf TEXT, pnl TEXT, real_rr TEXT, expect_rr TEXT, + entry_reason TEXT, exit_reason TEXT, hold_duration TEXT, + mood_issues TEXT, mood_score INTEGER, note TEXT + )""" + ) + conn.execute( + """INSERT INTO journal_entries VALUES (?,?,?,?,?,?,?,?,?,?,?)""", + ("BTC", "15m", "5", "1.2", "2.0", "突破", "止盈", "2小时", "", None, ""), + ) + row = conn.execute("SELECT * FROM journal_entries").fetchone() + conn.close() + text = journal_row_lines_for_ai(1, row) + self.assertIn("BTC 15m", text) + self.assertIn("实际RR:1.2", text) + self.assertIn("开仓逻辑:突破", text) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_archive_calendar.py b/tests/test_archive_calendar.py index 6882eed..e60a579 100644 --- a/tests/test_archive_calendar.py +++ b/tests/test_archive_calendar.py @@ -1,60 +1,60 @@ -import sqlite3 -import tempfile -import unittest -from datetime import datetime -from pathlib import Path -from zoneinfo import ZoneInfo - -from hub_symbol_archive_lib import init_db, list_archive_calendar, upsert_trades_cache, upsert_trade_overlay - - -def _bj_ms(y, m, d, hh, mm): - dt = datetime(y, m, d, hh, mm, 0, tzinfo=ZoneInfo("Asia/Shanghai")) - return int(dt.timestamp() * 1000) - - -class ArchiveCalendarTests(unittest.TestCase): - def test_calendar_groups_by_trading_day_and_sick(self): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "arch.db" - init_db(db) - upsert_trades_cache( - "binance", - [ - { - "id": 1, - "symbol": "BTC/USDT", - "direction": "long", - "result": "止盈", - "pnl_amount": 10.0, - "opened_at": "2026-06-18 09:00:00", - "closed_at": "2026-06-18 10:00:00", - "closed_at_ms": _bj_ms(2026, 6, 18, 10, 0), - "exchange_turnover_usdt": 2000.0, - "exchange_commission_usdt": 0.8, - }, - { - "id": 2, - "symbol": "ETH/USDT", - "direction": "short", - "result": "止损", - "pnl_amount": -5.0, - "opened_at": "2026-06-18 14:00:00", - "closed_at": "2026-06-18 15:00:00", - "closed_at_ms": _bj_ms(2026, 6, 18, 15, 0), - }, - ], - db_path=db, - ) - upsert_trade_overlay("binance", 2, behavior_tag="sick", db_path=db) - payload = list_archive_calendar(2026, 6, db_path=db) - self.assertEqual(payload["month"], 6) - days = payload["days"] - self.assertTrue(days) - sick_days = [d for d in days.values() if d.get("has_sick")] - self.assertTrue(sick_days) - self.assertGreaterEqual(payload["month_open_count"], 2) - - -if __name__ == "__main__": - unittest.main() +import sqlite3 +import tempfile +import unittest +from datetime import datetime +from pathlib import Path +from zoneinfo import ZoneInfo + +from lib.hub.hub_symbol_archive_lib import init_db, list_archive_calendar, upsert_trades_cache, upsert_trade_overlay + + +def _bj_ms(y, m, d, hh, mm): + dt = datetime(y, m, d, hh, mm, 0, tzinfo=ZoneInfo("Asia/Shanghai")) + return int(dt.timestamp() * 1000) + + +class ArchiveCalendarTests(unittest.TestCase): + def test_calendar_groups_by_trading_day_and_sick(self): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "arch.db" + init_db(db) + upsert_trades_cache( + "binance", + [ + { + "id": 1, + "symbol": "BTC/USDT", + "direction": "long", + "result": "止盈", + "pnl_amount": 10.0, + "opened_at": "2026-06-18 09:00:00", + "closed_at": "2026-06-18 10:00:00", + "closed_at_ms": _bj_ms(2026, 6, 18, 10, 0), + "exchange_turnover_usdt": 2000.0, + "exchange_commission_usdt": 0.8, + }, + { + "id": 2, + "symbol": "ETH/USDT", + "direction": "short", + "result": "止损", + "pnl_amount": -5.0, + "opened_at": "2026-06-18 14:00:00", + "closed_at": "2026-06-18 15:00:00", + "closed_at_ms": _bj_ms(2026, 6, 18, 15, 0), + }, + ], + db_path=db, + ) + upsert_trade_overlay("binance", 2, behavior_tag="sick", db_path=db) + payload = list_archive_calendar(2026, 6, db_path=db) + self.assertEqual(payload["month"], 6) + days = payload["days"] + self.assertTrue(days) + sick_days = [d for d in days.values() if d.get("has_sick")] + self.assertTrue(sick_days) + self.assertGreaterEqual(payload["month_open_count"], 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_daily_open_limit_lib.py b/tests/test_daily_open_limit_lib.py index 00d0c52..e501967 100644 --- a/tests/test_daily_open_limit_lib.py +++ b/tests/test_daily_open_limit_lib.py @@ -1,90 +1,90 @@ -import unittest - -from daily_open_limit_lib import ( - build_daily_open_alert_prompt, - can_trade_new_open, - check_daily_open_hard_limit, - count_opens_for_trading_day, - daily_open_hard_limit_blocks, - format_daily_open_counter_line, - hard_limit_block_reason, - load_daily_open_limits_from_env, - parse_daily_open_hard_limit, - should_send_daily_open_alert, -) - - -class _FakeConn: - def __init__(self, count: int): - self._count = count - - def execute(self, _sql, _params): - return self - - def fetchone(self): - return (self._count,) - - -class DailyOpenLimitLibTests(unittest.TestCase): - def test_parse_hard_limit_zero_disables(self): - self.assertEqual(parse_daily_open_hard_limit("0"), 0) - self.assertEqual(parse_daily_open_hard_limit(None, default=0), 0) - - def test_load_from_env(self): - alert, hard = load_daily_open_limits_from_env( - {"DAILY_OPEN_ALERT_THRESHOLD": "3", "DAILY_OPEN_HARD_LIMIT": "8"} - ) - self.assertEqual(alert, 3) - self.assertEqual(hard, 8) - - def test_hard_limit_blocks(self): - self.assertFalse(daily_open_hard_limit_blocks(4, 0)) - self.assertFalse(daily_open_hard_limit_blocks(4, 5)) - self.assertTrue(daily_open_hard_limit_blocks(5, 5)) - - def test_check_daily_open_hard_limit(self): - conn = _FakeConn(5) - ok, reason, n = check_daily_open_hard_limit(conn, "2026-06-07", 5, 8) - self.assertFalse(ok) - self.assertEqual(n, 5) - self.assertIn("已达上限", reason) - self.assertIn("8:00", reason) - - def test_count_opens(self): - self.assertEqual(count_opens_for_trading_day(_FakeConn(3), "2026-06-07"), 3) - - def test_can_trade_new_open(self): - self.assertTrue( - can_trade_new_open( - time_allows=True, - active_count=0, - max_active_positions=1, - opens_today=2, - hard_limit=5, - ) - ) - self.assertFalse( - can_trade_new_open( - time_allows=True, - active_count=0, - max_active_positions=1, - opens_today=5, - hard_limit=5, - ) - ) - - def test_alert_crossing(self): - self.assertTrue(should_send_daily_open_alert(4, 5, 5)) - self.assertFalse(should_send_daily_open_alert(5, 6, 5)) - - def test_prompt_includes_hard_limit(self): - txt = build_daily_open_alert_prompt("2026-06-07", 5, 5, hard_limit=8) - self.assertIn("硬上限 8", txt) - - def test_counter_line(self): - line = format_daily_open_counter_line(3, 5, 8) - self.assertIn("3 / 硬上限 8", line) - - -if __name__ == "__main__": - unittest.main() +import unittest + +from lib.trade.daily_open_limit_lib import ( + build_daily_open_alert_prompt, + can_trade_new_open, + check_daily_open_hard_limit, + count_opens_for_trading_day, + daily_open_hard_limit_blocks, + format_daily_open_counter_line, + hard_limit_block_reason, + load_daily_open_limits_from_env, + parse_daily_open_hard_limit, + should_send_daily_open_alert, +) + + +class _FakeConn: + def __init__(self, count: int): + self._count = count + + def execute(self, _sql, _params): + return self + + def fetchone(self): + return (self._count,) + + +class DailyOpenLimitLibTests(unittest.TestCase): + def test_parse_hard_limit_zero_disables(self): + self.assertEqual(parse_daily_open_hard_limit("0"), 0) + self.assertEqual(parse_daily_open_hard_limit(None, default=0), 0) + + def test_load_from_env(self): + alert, hard = load_daily_open_limits_from_env( + {"DAILY_OPEN_ALERT_THRESHOLD": "3", "DAILY_OPEN_HARD_LIMIT": "8"} + ) + self.assertEqual(alert, 3) + self.assertEqual(hard, 8) + + def test_hard_limit_blocks(self): + self.assertFalse(daily_open_hard_limit_blocks(4, 0)) + self.assertFalse(daily_open_hard_limit_blocks(4, 5)) + self.assertTrue(daily_open_hard_limit_blocks(5, 5)) + + def test_check_daily_open_hard_limit(self): + conn = _FakeConn(5) + ok, reason, n = check_daily_open_hard_limit(conn, "2026-06-07", 5, 8) + self.assertFalse(ok) + self.assertEqual(n, 5) + self.assertIn("已达上限", reason) + self.assertIn("8:00", reason) + + def test_count_opens(self): + self.assertEqual(count_opens_for_trading_day(_FakeConn(3), "2026-06-07"), 3) + + def test_can_trade_new_open(self): + self.assertTrue( + can_trade_new_open( + time_allows=True, + active_count=0, + max_active_positions=1, + opens_today=2, + hard_limit=5, + ) + ) + self.assertFalse( + can_trade_new_open( + time_allows=True, + active_count=0, + max_active_positions=1, + opens_today=5, + hard_limit=5, + ) + ) + + def test_alert_crossing(self): + self.assertTrue(should_send_daily_open_alert(4, 5, 5)) + self.assertFalse(should_send_daily_open_alert(5, 6, 5)) + + def test_prompt_includes_hard_limit(self): + txt = build_daily_open_alert_prompt("2026-06-07", 5, 5, hard_limit=8) + self.assertIn("硬上限 8", txt) + + def test_counter_line(self): + line = format_daily_open_counter_line(3, 5, 8) + self.assertIn("3 / 硬上限 8", line) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_false_breakout_key_monitor_lib.py b/tests/test_false_breakout_key_monitor_lib.py index 3d29c1f..92ccf7a 100644 --- a/tests/test_false_breakout_key_monitor_lib.py +++ b/tests/test_false_breakout_key_monitor_lib.py @@ -1,76 +1,76 @@ -import unittest -from datetime import datetime, timedelta - -from false_breakout_key_monitor_lib import ( - FALSE_BREAKOUT_MONITOR_TYPE, - calc_false_breakout_plan, - false_breakout_gate_preview, - is_false_breakout_expired, - key_price_from_row, - normalize_false_breakout_symbol, - storage_bounds_from_key_price, -) - - -class FalseBreakoutKeyMonitorLibTests(unittest.TestCase): - def test_normalize_symbol(self): - self.assertEqual(normalize_false_breakout_symbol("btc"), "BTC/USDT") - self.assertEqual(normalize_false_breakout_symbol("ETH/USDT"), "ETH/USDT") - self.assertIsNone(normalize_false_breakout_symbol("SOL")) - - def test_short_plan(self): - plan = calc_false_breakout_plan("short", 100000) - self.assertIsNotNone(plan) - entry, sl, tp = plan - self.assertAlmostEqual(entry, 100100.0) - self.assertAlmostEqual(sl, 100600.5) - self.assertAlmostEqual(tp, 99349.25) - - def test_long_plan(self): - plan = calc_false_breakout_plan("long", 100000) - self.assertIsNotNone(plan) - entry, sl, tp = plan - self.assertAlmostEqual(entry, 99900.0) - self.assertAlmostEqual(sl, 99400.5) - self.assertAlmostEqual(tp, 100649.25) - - def test_storage_bounds(self): - up, low = storage_bounds_from_key_price("short", 100000) - self.assertGreater(up, low) - self.assertAlmostEqual(up, 100000.0) - self.assertAlmostEqual(low, 99990.0) - up, low = storage_bounds_from_key_price("long", 100000) - self.assertGreater(up, low) - self.assertAlmostEqual(low, 100000.0) - self.assertAlmostEqual(up, 100010.0) - - def test_key_price_from_row(self): - self.assertEqual(key_price_from_row("short", 100100, 100000), 100100) - self.assertEqual(key_price_from_row("long", 100100, 100000), 100000) - - def test_expiry(self): - now = datetime(2026, 6, 9, 12, 0, 0) - created = "2026-06-08 12:00:00" - self.assertTrue(is_false_breakout_expired(created, now)) - self.assertFalse(is_false_breakout_expired(created, now - timedelta(hours=1))) - - def test_monitor_type_constant(self): - self.assertEqual(FALSE_BREAKOUT_MONITOR_TYPE, "假突破") - - def test_gate_preview_not_box_gate(self): - now = datetime(2026, 6, 7, 12, 0, 0) - prev = false_breakout_gate_preview( - entry_display="1635.0", - limit_order_id="oid-1", - created_at="2026-06-07 10:00:00", - now=now, - ) - self.assertIn("假突破", prev["summary"]) - self.assertIn("等待成交", prev["summary"]) - self.assertNotIn("量:", prev["summary"]) - self.assertIn("限价单:oid-1", prev["metrics"]) - self.assertTrue(prev["gate_ok"]) - - -if __name__ == "__main__": - unittest.main() +import unittest +from datetime import datetime, timedelta + +from lib.key_monitor.false_breakout_key_monitor_lib import ( + FALSE_BREAKOUT_MONITOR_TYPE, + calc_false_breakout_plan, + false_breakout_gate_preview, + is_false_breakout_expired, + key_price_from_row, + normalize_false_breakout_symbol, + storage_bounds_from_key_price, +) + + +class FalseBreakoutKeyMonitorLibTests(unittest.TestCase): + def test_normalize_symbol(self): + self.assertEqual(normalize_false_breakout_symbol("btc"), "BTC/USDT") + self.assertEqual(normalize_false_breakout_symbol("ETH/USDT"), "ETH/USDT") + self.assertIsNone(normalize_false_breakout_symbol("SOL")) + + def test_short_plan(self): + plan = calc_false_breakout_plan("short", 100000) + self.assertIsNotNone(plan) + entry, sl, tp = plan + self.assertAlmostEqual(entry, 100100.0) + self.assertAlmostEqual(sl, 100600.5) + self.assertAlmostEqual(tp, 99349.25) + + def test_long_plan(self): + plan = calc_false_breakout_plan("long", 100000) + self.assertIsNotNone(plan) + entry, sl, tp = plan + self.assertAlmostEqual(entry, 99900.0) + self.assertAlmostEqual(sl, 99400.5) + self.assertAlmostEqual(tp, 100649.25) + + def test_storage_bounds(self): + up, low = storage_bounds_from_key_price("short", 100000) + self.assertGreater(up, low) + self.assertAlmostEqual(up, 100000.0) + self.assertAlmostEqual(low, 99990.0) + up, low = storage_bounds_from_key_price("long", 100000) + self.assertGreater(up, low) + self.assertAlmostEqual(low, 100000.0) + self.assertAlmostEqual(up, 100010.0) + + def test_key_price_from_row(self): + self.assertEqual(key_price_from_row("short", 100100, 100000), 100100) + self.assertEqual(key_price_from_row("long", 100100, 100000), 100000) + + def test_expiry(self): + now = datetime(2026, 6, 9, 12, 0, 0) + created = "2026-06-08 12:00:00" + self.assertTrue(is_false_breakout_expired(created, now)) + self.assertFalse(is_false_breakout_expired(created, now - timedelta(hours=1))) + + def test_monitor_type_constant(self): + self.assertEqual(FALSE_BREAKOUT_MONITOR_TYPE, "假突破") + + def test_gate_preview_not_box_gate(self): + now = datetime(2026, 6, 7, 12, 0, 0) + prev = false_breakout_gate_preview( + entry_display="1635.0", + limit_order_id="oid-1", + created_at="2026-06-07 10:00:00", + now=now, + ) + self.assertIn("假突破", prev["summary"]) + self.assertIn("等待成交", prev["summary"]) + self.assertNotIn("量:", prev["summary"]) + self.assertIn("限价单:oid-1", prev["metrics"]) + self.assertTrue(prev["gate_ok"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_gate_position_history_lib.py b/tests/test_gate_position_history_lib.py index cecbf5e..919c9d0 100644 --- a/tests/test_gate_position_history_lib.py +++ b/tests/test_gate_position_history_lib.py @@ -1,26 +1,26 @@ -from gate_position_history_lib import pick_gate_position_close, unified_symbol_for_match - - -def test_unified_symbol_strips_settle_suffix(): - assert unified_symbol_for_match("BTC/USDT:USDT") == "BTC/USDT" - - -def test_pick_gate_position_close_matches_symbol_side_and_time(): - hist = [ - { - "symbol_u": "SOL/USDT", - "side": "short", - "close_ms": 1_700_000_000_000, - "open_ms": 1_699_999_000_000, - "pnl": -1.25, - "sync_key": "SOL_USDT|1|short", - } - ] - hit = pick_gate_position_close( - hist, - "SOL/USDT:USDT", - "short", - opened_at_ms=1_699_999_500_000, - ) - assert hit is not None - assert hit["pnl"] == -1.25 +from lib.exchange.gate_position_history_lib import pick_gate_position_close, unified_symbol_for_match + + +def test_unified_symbol_strips_settle_suffix(): + assert unified_symbol_for_match("BTC/USDT:USDT") == "BTC/USDT" + + +def test_pick_gate_position_close_matches_symbol_side_and_time(): + hist = [ + { + "symbol_u": "SOL/USDT", + "side": "short", + "close_ms": 1_700_000_000_000, + "open_ms": 1_699_999_000_000, + "pnl": -1.25, + "sync_key": "SOL_USDT|1|short", + } + ] + hit = pick_gate_position_close( + hist, + "SOL/USDT:USDT", + "short", + opened_at_ms=1_699_999_500_000, + ) + assert hit is not None + assert hit["pnl"] == -1.25 diff --git a/tests/test_gate_transfer_lib.py b/tests/test_gate_transfer_lib.py index bb8ed09..4506492 100644 --- a/tests/test_gate_transfer_lib.py +++ b/tests/test_gate_transfer_lib.py @@ -1,44 +1,44 @@ -"""gate_transfer_lib 单元测试。""" -from __future__ import annotations - -import sqlite3 -import unittest - -from gate_transfer_lib import count_auto_transfer_blockers - - -class GateTransferLibTest(unittest.TestCase): - def test_counts_order_monitors_first(self): - conn = sqlite3.connect(":memory:") - conn.execute("CREATE TABLE order_monitors (status TEXT)") - conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)") - conn.execute("INSERT INTO order_monitors VALUES ('active')") - conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 1)") - conn.commit() - n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 1) - self.assertEqual(n, 1) - conn.close() - - def test_counts_trend_plan_when_no_order_monitors(self): - conn = sqlite3.connect(":memory:") - conn.execute("CREATE TABLE order_monitors (status TEXT)") - conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)") - conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 1)") - conn.commit() - n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 0) - self.assertEqual(n, 1) - conn.close() - - def test_ignores_trend_plan_without_first_order(self): - conn = sqlite3.connect(":memory:") - conn.execute("CREATE TABLE order_monitors (status TEXT)") - conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)") - conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 0)") - conn.commit() - n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 0) - self.assertEqual(n, 0) - conn.close() - - -if __name__ == "__main__": - unittest.main() +"""gate_transfer_lib 单元测试。""" +from __future__ import annotations + +import sqlite3 +import unittest + +from lib.exchange.gate_transfer_lib import count_auto_transfer_blockers + + +class GateTransferLibTest(unittest.TestCase): + def test_counts_order_monitors_first(self): + conn = sqlite3.connect(":memory:") + conn.execute("CREATE TABLE order_monitors (status TEXT)") + conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)") + conn.execute("INSERT INTO order_monitors VALUES ('active')") + conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 1)") + conn.commit() + n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 1) + self.assertEqual(n, 1) + conn.close() + + def test_counts_trend_plan_when_no_order_monitors(self): + conn = sqlite3.connect(":memory:") + conn.execute("CREATE TABLE order_monitors (status TEXT)") + conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)") + conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 1)") + conn.commit() + n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 0) + self.assertEqual(n, 1) + conn.close() + + def test_ignores_trend_plan_without_first_order(self): + conn = sqlite3.connect(":memory:") + conn.execute("CREATE TABLE order_monitors (status TEXT)") + conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)") + conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 0)") + conn.commit() + n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 0) + self.assertEqual(n, 0) + conn.close() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hub_agent_mark_price.py b/tests/test_hub_agent_mark_price.py index e80d922..c042706 100644 --- a/tests/test_hub_agent_mark_price.py +++ b/tests/test_hub_agent_mark_price.py @@ -1,94 +1,94 @@ -"""子代理持仓:四所标记价字段统一解析。""" -from __future__ import annotations - -import sys -import unittest -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(ROOT / "manual_trading_hub")) - -from agent import _position_mark_price, _ticker_mark_price # noqa: E402 - -sys.path.insert(0, str(ROOT)) -from hub_position_metrics import ( # noqa: E402 - enrich_ccxt_position_metrics_out, - estimate_linear_swap_upnl_usdt, - parse_position_unrealized_pnl, - resolve_position_display_upnl, -) - - -class TestHubAgentMarkPrice(unittest.TestCase): - def test_binance_mark_price(self): - px = _position_mark_price({"markPrice": 65880.1, "info": {}}) - self.assertAlmostEqual(px, 65880.1) - - def test_okx_mark_px(self): - px = _position_mark_price({"info": {"markPx": "72.85"}}) - self.assertAlmostEqual(px, 72.85) - - def test_gate_info_mark(self): - px = _position_mark_price({"info": {"mark_price": "0.2241"}}) - self.assertAlmostEqual(px, 0.2241) - - def test_missing_returns_none(self): - self.assertIsNone(_position_mark_price({"info": {}})) - - def test_infer_from_notional_and_contracts(self): - p = {"notional": 1000, "contracts": 10, "info": {}} - px = _position_mark_price(p) - self.assertAlmostEqual(px, 100.0) - - def test_ticker_fallback(self): - class _Ex: - def fetch_ticker(self, sym): - return {"mark": 99.5, "info": {}} - - self.assertAlmostEqual(_ticker_mark_price(_Ex(), "BTC/USDT:USDT"), 99.5) - - def test_gate_unrealised_pnl_in_info(self): - pnl = parse_position_unrealized_pnl( - {"info": {"unrealised_pnl": "6.81"}, "unrealizedPnl": None} - ) - self.assertAlmostEqual(pnl, 6.81) - - def test_okx_upl_signed(self): - pnl = parse_position_unrealized_pnl( - {"info": {"upl": "-2.15"}, "unrealizedPnl": None} - ) - self.assertAlmostEqual(pnl, -2.15) - - def test_enrich_aligns_short_gate_metrics(self): - pos = { - "side": "short", - "contracts": 11, - "entryPrice": 73.187, - "markPrice": 66.038, - "info": {"unrealised_pnl": "7.86"}, - } - out = {"unrealized_pnl": 7.86, "mark_price": 66.038} - enrich_ccxt_position_metrics_out(pos, out, contract_size=1.0, funds_decimals=2) - self.assertGreater(out["unrealized_pnl"], 70.0) - - def test_estimate_short_hype_contract_size(self): - upnl = estimate_linear_swap_upnl_usdt( - "short", 73.187, 66.038, 11, 0.1 - ) - self.assertAlmostEqual(upnl, 7.86, places=1) - - def test_resolve_prefers_computed_when_exchange_off(self): - shown = resolve_position_display_upnl( - "short", 73.187, 66.038, 11, 1.0, 7.86 - ) - self.assertAlmostEqual(shown, 78.64, places=1) - - def test_resolve_keeps_exchange_when_aligned(self): - shown = resolve_position_display_upnl( - "short", 73.187, 66.038, 11, 0.1, 7.86 - ) - self.assertAlmostEqual(shown, 7.86, places=2) - - -if __name__ == "__main__": - unittest.main() +"""子代理持仓:四所标记价字段统一解析。""" +from __future__ import annotations + +import sys +import unittest +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "manual_trading_hub")) + +from agent import _position_mark_price, _ticker_mark_price # noqa: E402 + +sys.path.insert(0, str(ROOT)) +from lib.hub.hub_position_metrics import ( # noqa: E402 + enrich_ccxt_position_metrics_out, + estimate_linear_swap_upnl_usdt, + parse_position_unrealized_pnl, + resolve_position_display_upnl, +) + + +class TestHubAgentMarkPrice(unittest.TestCase): + def test_binance_mark_price(self): + px = _position_mark_price({"markPrice": 65880.1, "info": {}}) + self.assertAlmostEqual(px, 65880.1) + + def test_okx_mark_px(self): + px = _position_mark_price({"info": {"markPx": "72.85"}}) + self.assertAlmostEqual(px, 72.85) + + def test_gate_info_mark(self): + px = _position_mark_price({"info": {"mark_price": "0.2241"}}) + self.assertAlmostEqual(px, 0.2241) + + def test_missing_returns_none(self): + self.assertIsNone(_position_mark_price({"info": {}})) + + def test_infer_from_notional_and_contracts(self): + p = {"notional": 1000, "contracts": 10, "info": {}} + px = _position_mark_price(p) + self.assertAlmostEqual(px, 100.0) + + def test_ticker_fallback(self): + class _Ex: + def fetch_ticker(self, sym): + return {"mark": 99.5, "info": {}} + + self.assertAlmostEqual(_ticker_mark_price(_Ex(), "BTC/USDT:USDT"), 99.5) + + def test_gate_unrealised_pnl_in_info(self): + pnl = parse_position_unrealized_pnl( + {"info": {"unrealised_pnl": "6.81"}, "unrealizedPnl": None} + ) + self.assertAlmostEqual(pnl, 6.81) + + def test_okx_upl_signed(self): + pnl = parse_position_unrealized_pnl( + {"info": {"upl": "-2.15"}, "unrealizedPnl": None} + ) + self.assertAlmostEqual(pnl, -2.15) + + def test_enrich_aligns_short_gate_metrics(self): + pos = { + "side": "short", + "contracts": 11, + "entryPrice": 73.187, + "markPrice": 66.038, + "info": {"unrealised_pnl": "7.86"}, + } + out = {"unrealized_pnl": 7.86, "mark_price": 66.038} + enrich_ccxt_position_metrics_out(pos, out, contract_size=1.0, funds_decimals=2) + self.assertGreater(out["unrealized_pnl"], 70.0) + + def test_estimate_short_hype_contract_size(self): + upnl = estimate_linear_swap_upnl_usdt( + "short", 73.187, 66.038, 11, 0.1 + ) + self.assertAlmostEqual(upnl, 7.86, places=1) + + def test_resolve_prefers_computed_when_exchange_off(self): + shown = resolve_position_display_upnl( + "short", 73.187, 66.038, 11, 1.0, 7.86 + ) + self.assertAlmostEqual(shown, 78.64, places=1) + + def test_resolve_keeps_exchange_when_aligned(self): + shown = resolve_position_display_upnl( + "short", 73.187, 66.038, 11, 0.1, 7.86 + ) + self.assertAlmostEqual(shown, 7.86, places=2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hub_calculator_lib.py b/tests/test_hub_calculator_lib.py index ef43206..82a1d61 100644 --- a/tests/test_hub_calculator_lib.py +++ b/tests/test_hub_calculator_lib.py @@ -1,162 +1,162 @@ -"""hub_calculator_lib 测算逻辑。""" - -import unittest -from unittest.mock import patch - -from hub_calculator_lib import ( - calc_initial_roll_qty, - calc_roll_calculator, - calc_trend_calculator, - solve_add_amount_for_total_risk, -) - -MOCK_MARKET = { - "exchange_id": "0", - "exchange_key": "binance", - "exchange_name": "币安 · crypto_monitor_binance", - "exchange_label": "币安 · crypto_monitor_binance", - "base": "ETH", - "exchange_symbol": "ETH/USDT:USDT", - "display_symbol": "ETH/USDT", - "contract_size": 1.0, - "price_tick": 0.01, - "price_decimals": 2, - "amount_decimals": 3, - "min_amount": 0.001, -} - - -def _mock_resolve(_exchange="binance", _base="ETH"): - return MOCK_MARKET, lambda amount: round(float(amount), 3), None - - -class HubCalculatorLibTests(unittest.TestCase): - @patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve()) - def test_trend_calculator_long_basic(self, _mock): - data, err = calc_trend_calculator( - direction="long", - capital_usdt=1000, - risk_percent=5, - leverage=5, - entry_price=100, - stop_loss=95, - add_upper=110, - take_profit=120, - dca_legs=3, - exchange_id="0", - base="ETH", - ) - self.assertIsNone(err) - self.assertIsNotNone(data) - assert data is not None - self.assertEqual(data["risk_budget_u"], 50.0) - self.assertGreaterEqual(len(data["rows"]), 2) - self.assertEqual(data["rows"][0]["label"], "首仓") - self.assertEqual(data["market"]["display_symbol"], "ETH/USDT") - - @patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve()) - def test_trend_calculator_short_rejects_bad_bounds(self, _mock): - data, err = calc_trend_calculator( - direction="short", - capital_usdt=1000, - risk_percent=5, - leverage=5, - entry_price=100, - stop_loss=90, - add_upper=110, - take_profit=80, - dca_legs=3, - ) - self.assertIsNone(data) - self.assertIsNotNone(err) - - @patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve()) - def test_roll_calculator_first_leg_auto(self, _mock): - data, err = calc_roll_calculator( - direction="long", - capital_usdt=1000, - risk_percent=5, - entry_price=100, - stop_loss=95, - take_profit=120, - add_legs=[], - legs_done=0, - ) - self.assertIsNone(err) - self.assertIsNotNone(data) - assert data is not None - self.assertEqual(data["first_contracts"], 10.0) - self.assertEqual(len(data["rows"]), 1) - self.assertEqual(data["rows"][0]["loss_at_sl_u"], 50.0) - self.assertEqual(data["rows"][0]["profit_at_tp_u"], 200.0) - - @patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve()) - def test_roll_calculator_chain_two_legs(self, _mock): - data, err = calc_roll_calculator( - direction="long", - capital_usdt=1000, - risk_percent=5, - entry_price=100, - stop_loss=95, - take_profit=120, - add_legs=[ - {"add_price": 105, "new_stop_loss": 98}, - {"add_price": 108, "new_stop_loss": 101}, - ], - legs_done=0, - ) - self.assertIsNone(err) - self.assertIsNotNone(data) - assert data is not None - self.assertEqual(len(data["rows"]), 3) - self.assertEqual(data["rows"][1]["label"], "滚仓1") - self.assertGreater(float(data["final_contracts"]), float(data["first_contracts"])) - - @patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve()) - def test_roll_calculator_rejects_too_many_legs(self, _mock): - data, err = calc_roll_calculator( - direction="long", - capital_usdt=1000, - risk_percent=5, - entry_price=100, - stop_loss=95, - take_profit=120, - add_legs=[ - {"add_price": 105, "new_stop_loss": 98}, - {"add_price": 108, "new_stop_loss": 101}, - {"add_price": 110, "new_stop_loss": 103}, - {"add_price": 112, "new_stop_loss": 105}, - ], - legs_done=0, - ) - self.assertIsNone(data) - self.assertIsNotNone(err) - - def test_initial_roll_qty(self): - qty, err = calc_initial_roll_qty("long", 100, 95, 50, 1.0) - self.assertIsNone(err) - self.assertEqual(qty, 10.0) - - def test_initial_roll_qty_with_contract_size(self): - qty, err = calc_initial_roll_qty("long", 100, 95, 50, 0.1) - self.assertIsNone(err) - self.assertEqual(qty, 100.0) - - def test_solve_add_with_contract_size(self): - q2, err = solve_add_amount_for_total_risk( - "long", - qty_existing=10.0, - entry_existing=100.0, - add_price=105.0, - new_stop=98.0, - risk_budget_usdt=50.0, - contract_size=1.0, - ) - self.assertIsNone(err) - self.assertIsNotNone(q2) - assert q2 is not None - self.assertGreater(q2, 0) - - -if __name__ == "__main__": - unittest.main() +"""hub_calculator_lib 测算逻辑。""" + +import unittest +from unittest.mock import patch + +from lib.hub.hub_calculator_lib import ( + calc_initial_roll_qty, + calc_roll_calculator, + calc_trend_calculator, + solve_add_amount_for_total_risk, +) + +MOCK_MARKET = { + "exchange_id": "0", + "exchange_key": "binance", + "exchange_name": "币安 · crypto_monitor_binance", + "exchange_label": "币安 · crypto_monitor_binance", + "base": "ETH", + "exchange_symbol": "ETH/USDT:USDT", + "display_symbol": "ETH/USDT", + "contract_size": 1.0, + "price_tick": 0.01, + "price_decimals": 2, + "amount_decimals": 3, + "min_amount": 0.001, +} + + +def _mock_resolve(_exchange="binance", _base="ETH"): + return MOCK_MARKET, lambda amount: round(float(amount), 3), None + + +class HubCalculatorLibTests(unittest.TestCase): + @patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve()) + def test_trend_calculator_long_basic(self, _mock): + data, err = calc_trend_calculator( + direction="long", + capital_usdt=1000, + risk_percent=5, + leverage=5, + entry_price=100, + stop_loss=95, + add_upper=110, + take_profit=120, + dca_legs=3, + exchange_id="0", + base="ETH", + ) + self.assertIsNone(err) + self.assertIsNotNone(data) + assert data is not None + self.assertEqual(data["risk_budget_u"], 50.0) + self.assertGreaterEqual(len(data["rows"]), 2) + self.assertEqual(data["rows"][0]["label"], "首仓") + self.assertEqual(data["market"]["display_symbol"], "ETH/USDT") + + @patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve()) + def test_trend_calculator_short_rejects_bad_bounds(self, _mock): + data, err = calc_trend_calculator( + direction="short", + capital_usdt=1000, + risk_percent=5, + leverage=5, + entry_price=100, + stop_loss=90, + add_upper=110, + take_profit=80, + dca_legs=3, + ) + self.assertIsNone(data) + self.assertIsNotNone(err) + + @patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve()) + def test_roll_calculator_first_leg_auto(self, _mock): + data, err = calc_roll_calculator( + direction="long", + capital_usdt=1000, + risk_percent=5, + entry_price=100, + stop_loss=95, + take_profit=120, + add_legs=[], + legs_done=0, + ) + self.assertIsNone(err) + self.assertIsNotNone(data) + assert data is not None + self.assertEqual(data["first_contracts"], 10.0) + self.assertEqual(len(data["rows"]), 1) + self.assertEqual(data["rows"][0]["loss_at_sl_u"], 50.0) + self.assertEqual(data["rows"][0]["profit_at_tp_u"], 200.0) + + @patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve()) + def test_roll_calculator_chain_two_legs(self, _mock): + data, err = calc_roll_calculator( + direction="long", + capital_usdt=1000, + risk_percent=5, + entry_price=100, + stop_loss=95, + take_profit=120, + add_legs=[ + {"add_price": 105, "new_stop_loss": 98}, + {"add_price": 108, "new_stop_loss": 101}, + ], + legs_done=0, + ) + self.assertIsNone(err) + self.assertIsNotNone(data) + assert data is not None + self.assertEqual(len(data["rows"]), 3) + self.assertEqual(data["rows"][1]["label"], "滚仓1") + self.assertGreater(float(data["final_contracts"]), float(data["first_contracts"])) + + @patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve()) + def test_roll_calculator_rejects_too_many_legs(self, _mock): + data, err = calc_roll_calculator( + direction="long", + capital_usdt=1000, + risk_percent=5, + entry_price=100, + stop_loss=95, + take_profit=120, + add_legs=[ + {"add_price": 105, "new_stop_loss": 98}, + {"add_price": 108, "new_stop_loss": 101}, + {"add_price": 110, "new_stop_loss": 103}, + {"add_price": 112, "new_stop_loss": 105}, + ], + legs_done=0, + ) + self.assertIsNone(data) + self.assertIsNotNone(err) + + def test_initial_roll_qty(self): + qty, err = calc_initial_roll_qty("long", 100, 95, 50, 1.0) + self.assertIsNone(err) + self.assertEqual(qty, 10.0) + + def test_initial_roll_qty_with_contract_size(self): + qty, err = calc_initial_roll_qty("long", 100, 95, 50, 0.1) + self.assertIsNone(err) + self.assertEqual(qty, 100.0) + + def test_solve_add_with_contract_size(self): + q2, err = solve_add_amount_for_total_risk( + "long", + qty_existing=10.0, + entry_existing=100.0, + add_price=105.0, + new_stop=98.0, + risk_budget_usdt=50.0, + contract_size=1.0, + ) + self.assertIsNone(err) + self.assertIsNotNone(q2) + assert q2 is not None + self.assertGreater(q2, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hub_calculator_market_lib.py b/tests/test_hub_calculator_market_lib.py index b5ca102..ba07a04 100644 --- a/tests/test_hub_calculator_market_lib.py +++ b/tests/test_hub_calculator_market_lib.py @@ -1,113 +1,113 @@ -"""hub_calculator_market_lib 合约解析。""" - -import unittest -from unittest.mock import patch - -from hub_calculator_market_lib import ( - amount_decimals_from_exchange, - find_exchange, - get_calculator_market, - list_calculator_exchanges, - make_amount_precise_fn_from_market, - normalize_base_symbol, - resolve_usdt_perp_symbol, -) - - -class FakeExchange: - def __init__(self, markets: dict): - self.markets = markets - - def market(self, symbol: str): - return self.markets[symbol] - - def amount_to_precision(self, symbol: str, amount: float) -> str: - return f"{float(amount):.3f}" - - -class HubCalculatorMarketLibTests(unittest.TestCase): - def test_normalize_base_symbol(self): - self.assertEqual(normalize_base_symbol("eth"), "ETH") - self.assertEqual(normalize_base_symbol("ETH/USDT:USDT"), "ETH") - self.assertEqual(normalize_base_symbol("ETHUSDT"), "ETH") - - def test_resolve_usdt_perp_symbol(self): - ex = FakeExchange( - { - "ETH/USDT:USDT": { - "base": "ETH", - "quote": "USDT", - "swap": True, - "active": True, - "contractSize": 1.0, - "limits": {"amount": {"min": 0.001}}, - "precision": {"price": 2, "amount": 3}, - } - } - ) - sym, err = resolve_usdt_perp_symbol(ex, "ETH") - self.assertIsNone(err) - self.assertEqual(sym, "ETH/USDT:USDT") - - def test_amount_decimals_from_exchange(self): - ex = FakeExchange({}) - self.assertEqual(amount_decimals_from_exchange(ex, "ETH/USDT:USDT"), 3) - - def test_make_amount_precise_fn_from_market(self): - fn = make_amount_precise_fn_from_market({"amount_decimals": 3, "min_amount": 0.001}) - self.assertEqual(fn(1.23456), 1.234) - self.assertIsNone(fn(0.0001)) - - @patch.dict("os.environ", {"HUB_BRIDGE_TOKEN": "test-token"}, clear=False) - def test_hub_headers_use_x_hub_token(self): - from hub_calculator_market_lib import _hub_headers - - self.assertEqual(_hub_headers(), {"X-Hub-Token": "test-token"}) - - @patch("hub_calculator_market_lib.fetch_instance_market_sync") - def test_get_calculator_market_from_instance(self, fetch_mock): - fetch_mock.return_value = { - "ok": True, - "base": "ETH", - "exchange_symbol": "ETH/USDT:USDT", - "display_symbol": "ETH/USDT", - "contract_size": 0.01, - "price_tick": 0.01, - "price_decimals": 2, - "amount_decimals": 2, - "min_amount": 0.01, - } - ex = { - "id": "0", - "key": "binance", - "name": "币安 · crypto_monitor_binance", - "enabled": True, - "flask_url": "http://127.0.0.1:5001", - } - data, err = get_calculator_market("0", "ETH", ex=ex) - self.assertIsNone(err) - self.assertIsNotNone(data) - assert data is not None - self.assertEqual(data["exchange_id"], "0") - self.assertEqual(data["exchange_name"], "币安 · crypto_monitor_binance") - self.assertEqual(data["contract_size"], 0.01) - - @patch("hub_calculator_market_lib.enabled_exchanges") - def test_list_calculator_exchanges(self, enabled_mock): - enabled_mock.return_value = [ - {"id": "0", "key": "binance", "name": "币安", "enabled": True}, - ] - rows = list_calculator_exchanges() - self.assertEqual(len(rows), 1) - self.assertEqual(rows[0]["id"], "0") - - def test_find_exchange_by_id(self): - with patch( - "hub_calculator_market_lib.load_settings", - return_value={"exchanges": [{"id": "2", "key": "gate", "name": "Gate"}]}, - ): - self.assertEqual(find_exchange("2")["key"], "gate") - - -if __name__ == "__main__": - unittest.main() +"""hub_calculator_market_lib 合约解析。""" + +import unittest +from unittest.mock import patch + +from lib.hub.hub_calculator_market_lib import ( + amount_decimals_from_exchange, + find_exchange, + get_calculator_market, + list_calculator_exchanges, + make_amount_precise_fn_from_market, + normalize_base_symbol, + resolve_usdt_perp_symbol, +) + + +class FakeExchange: + def __init__(self, markets: dict): + self.markets = markets + + def market(self, symbol: str): + return self.markets[symbol] + + def amount_to_precision(self, symbol: str, amount: float) -> str: + return f"{float(amount):.3f}" + + +class HubCalculatorMarketLibTests(unittest.TestCase): + def test_normalize_base_symbol(self): + self.assertEqual(normalize_base_symbol("eth"), "ETH") + self.assertEqual(normalize_base_symbol("ETH/USDT:USDT"), "ETH") + self.assertEqual(normalize_base_symbol("ETHUSDT"), "ETH") + + def test_resolve_usdt_perp_symbol(self): + ex = FakeExchange( + { + "ETH/USDT:USDT": { + "base": "ETH", + "quote": "USDT", + "swap": True, + "active": True, + "contractSize": 1.0, + "limits": {"amount": {"min": 0.001}}, + "precision": {"price": 2, "amount": 3}, + } + } + ) + sym, err = resolve_usdt_perp_symbol(ex, "ETH") + self.assertIsNone(err) + self.assertEqual(sym, "ETH/USDT:USDT") + + def test_amount_decimals_from_exchange(self): + ex = FakeExchange({}) + self.assertEqual(amount_decimals_from_exchange(ex, "ETH/USDT:USDT"), 3) + + def test_make_amount_precise_fn_from_market(self): + fn = make_amount_precise_fn_from_market({"amount_decimals": 3, "min_amount": 0.001}) + self.assertEqual(fn(1.23456), 1.234) + self.assertIsNone(fn(0.0001)) + + @patch.dict("os.environ", {"HUB_BRIDGE_TOKEN": "test-token"}, clear=False) + def test_hub_headers_use_x_hub_token(self): + from lib.hub.hub_calculator_market_lib import _hub_headers + + self.assertEqual(_hub_headers(), {"X-Hub-Token": "test-token"}) + + @patch("lib.hub.hub_calculator_market_lib.fetch_instance_market_sync") + def test_get_calculator_market_from_instance(self, fetch_mock): + fetch_mock.return_value = { + "ok": True, + "base": "ETH", + "exchange_symbol": "ETH/USDT:USDT", + "display_symbol": "ETH/USDT", + "contract_size": 0.01, + "price_tick": 0.01, + "price_decimals": 2, + "amount_decimals": 2, + "min_amount": 0.01, + } + ex = { + "id": "0", + "key": "binance", + "name": "币安 · crypto_monitor_binance", + "enabled": True, + "flask_url": "http://127.0.0.1:5001", + } + data, err = get_calculator_market("0", "ETH", ex=ex) + self.assertIsNone(err) + self.assertIsNotNone(data) + assert data is not None + self.assertEqual(data["exchange_id"], "0") + self.assertEqual(data["exchange_name"], "币安 · crypto_monitor_binance") + self.assertEqual(data["contract_size"], 0.01) + + @patch("lib.hub.hub_calculator_market_lib.enabled_exchanges") + def test_list_calculator_exchanges(self, enabled_mock): + enabled_mock.return_value = [ + {"id": "0", "key": "binance", "name": "币安", "enabled": True}, + ] + rows = list_calculator_exchanges() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["id"], "0") + + def test_find_exchange_by_id(self): + with patch( + "lib.hub.hub_calculator_market_lib.load_settings", + return_value={"exchanges": [{"id": "2", "key": "gate", "name": "Gate"}]}, + ): + self.assertEqual(find_exchange("2")["key"], "gate") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hub_entry_plan_lib.py b/tests/test_hub_entry_plan_lib.py index 63c98e4..ebecbdd 100644 --- a/tests/test_hub_entry_plan_lib.py +++ b/tests/test_hub_entry_plan_lib.py @@ -1,157 +1,157 @@ -"""开仓计划库:CRUD 与胜率统计。""" - -from __future__ import annotations - -import tempfile -from pathlib import Path - -from hub_entry_plan_lib import ( - compute_entry_plan_stats, - create_entry_plan, - delete_entry_plan, - init_db, - list_entry_plans, - normalize_plan_symbol, - resolve_stats_date_bounds, - update_entry_plan, -) - - -def _base_payload(**overrides): - data = { - "plan_date": "2026-06-14", - "exchange_key": "binance", - "symbol": "BTC", - "plan_type": "trend", - "trend_timeframe": "4h", - "entry_timeframe": "15m", - "direction": "long", - "target_level": "70000", - "current_range": "68000-69000", - "entry_scheme": "breakout", - "note": "test", - } - data.update(overrides) - return data - - -def test_normalize_plan_symbol(): - assert normalize_plan_symbol("btc") == "BTC/USDT" - assert normalize_plan_symbol("ETH/USDT") == "ETH/USDT" - - -def test_create_without_entry_scheme(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "plans.db" - payload = _base_payload() - del payload["entry_scheme"] - row = create_entry_plan(payload, db_path=db) - assert row["entry_scheme"] == "" - assert row["entry_scheme_label"] == "待填写" - - -def test_archive_requires_entry_scheme(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "plans.db" - payload = _base_payload() - del payload["entry_scheme"] - row = create_entry_plan(payload, db_path=db) - try: - update_entry_plan(int(row["id"]), {"result": "win"}, db_path=db) - assert False, "expected ValueError" - except ValueError as e: - assert "入场方案" in str(e) - updated = update_entry_plan( - int(row["id"]), - {"entry_scheme": "breakout", "result": "win"}, - db_path=db, - ) - assert updated["status"] == "archived" - - -def test_create_list_delete_active_plan(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "plans.db" - row = create_entry_plan(_base_payload(), db_path=db) - assert row["status"] == "active" - assert row["symbol"] == "BTC/USDT" - active = list_entry_plans(status="active", db_path=db) - assert len(active) == 1 - assert delete_entry_plan(int(row["id"]), db_path=db) is True - assert list_entry_plans(status="active", db_path=db) == [] - - -def test_archive_on_result(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "plans.db" - row = create_entry_plan(_base_payload(symbol="SOL"), db_path=db) - updated = update_entry_plan( - int(row["id"]), - {"result": "win", "pnl_amount": 12.5}, - db_path=db, - ) - assert updated["status"] == "archived" - assert updated["result"] == "win" - assert updated["pnl_amount"] == 12.5 - assert list_entry_plans(status="active", db_path=db) == [] - archived = list_entry_plans(status="archived", db_path=db) - assert len(archived) == 1 - - -def test_archive_without_pnl_amount(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "plans.db" - row = create_entry_plan(_base_payload(symbol="DOGE"), db_path=db) - updated = update_entry_plan(int(row["id"]), {"result": "loss"}, db_path=db) - assert updated["status"] == "archived" - assert updated["pnl_amount"] is None - - -def test_cannot_delete_archived(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "plans.db" - row = create_entry_plan(_base_payload(), db_path=db) - update_entry_plan(int(row["id"]), {"result": "win"}, db_path=db) - try: - delete_entry_plan(int(row["id"]), db_path=db) - assert False, "expected ValueError" - except ValueError as e: - assert "仅进行中" in str(e) - - -def test_compute_stats_by_symbol(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "plans.db" - for sym, res in (("BTC", "win"), ("BTC", "loss"), ("ETH", "win")): - row = create_entry_plan(_base_payload(symbol=sym), db_path=db) - update_entry_plan(int(row["id"]), {"result": res}, db_path=db) - stats = compute_entry_plan_stats(dimension="symbol", period="all", db_path=db) - by_sym = {it["key"]: it for it in stats["items"]} - assert by_sym["BTC/USDT"]["win_count"] == 1 - assert by_sym["BTC/USDT"]["loss_count"] == 1 - assert by_sym["BTC/USDT"]["win_rate"] == 50.0 - assert by_sym["ETH/USDT"]["win_count"] == 1 - - -def test_stats_period_range_filter(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "plans.db" - row1 = create_entry_plan(_base_payload(plan_date="2026-06-01"), db_path=db) - row2 = create_entry_plan(_base_payload(plan_date="2026-06-20", symbol="ETH"), db_path=db) - update_entry_plan(int(row1["id"]), {"result": "win"}, db_path=db) - update_entry_plan(int(row2["id"]), {"result": "loss"}, db_path=db) - stats = compute_entry_plan_stats( - dimension="symbol", - period="range", - date_from="2026-06-01", - date_to="2026-06-10", - db_path=db, - ) - assert len(stats["items"]) == 1 - assert stats["items"][0]["key"] == "BTC/USDT" - - -def test_resolve_stats_date_bounds(): - df, dt, label = resolve_stats_date_bounds(period="all") - assert df is None and dt is None - assert "全部" in label +"""开仓计划库:CRUD 与胜率统计。""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +from lib.hub.hub_entry_plan_lib import ( + compute_entry_plan_stats, + create_entry_plan, + delete_entry_plan, + init_db, + list_entry_plans, + normalize_plan_symbol, + resolve_stats_date_bounds, + update_entry_plan, +) + + +def _base_payload(**overrides): + data = { + "plan_date": "2026-06-14", + "exchange_key": "binance", + "symbol": "BTC", + "plan_type": "trend", + "trend_timeframe": "4h", + "entry_timeframe": "15m", + "direction": "long", + "target_level": "70000", + "current_range": "68000-69000", + "entry_scheme": "breakout", + "note": "test", + } + data.update(overrides) + return data + + +def test_normalize_plan_symbol(): + assert normalize_plan_symbol("btc") == "BTC/USDT" + assert normalize_plan_symbol("ETH/USDT") == "ETH/USDT" + + +def test_create_without_entry_scheme(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "plans.db" + payload = _base_payload() + del payload["entry_scheme"] + row = create_entry_plan(payload, db_path=db) + assert row["entry_scheme"] == "" + assert row["entry_scheme_label"] == "待填写" + + +def test_archive_requires_entry_scheme(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "plans.db" + payload = _base_payload() + del payload["entry_scheme"] + row = create_entry_plan(payload, db_path=db) + try: + update_entry_plan(int(row["id"]), {"result": "win"}, db_path=db) + assert False, "expected ValueError" + except ValueError as e: + assert "入场方案" in str(e) + updated = update_entry_plan( + int(row["id"]), + {"entry_scheme": "breakout", "result": "win"}, + db_path=db, + ) + assert updated["status"] == "archived" + + +def test_create_list_delete_active_plan(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "plans.db" + row = create_entry_plan(_base_payload(), db_path=db) + assert row["status"] == "active" + assert row["symbol"] == "BTC/USDT" + active = list_entry_plans(status="active", db_path=db) + assert len(active) == 1 + assert delete_entry_plan(int(row["id"]), db_path=db) is True + assert list_entry_plans(status="active", db_path=db) == [] + + +def test_archive_on_result(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "plans.db" + row = create_entry_plan(_base_payload(symbol="SOL"), db_path=db) + updated = update_entry_plan( + int(row["id"]), + {"result": "win", "pnl_amount": 12.5}, + db_path=db, + ) + assert updated["status"] == "archived" + assert updated["result"] == "win" + assert updated["pnl_amount"] == 12.5 + assert list_entry_plans(status="active", db_path=db) == [] + archived = list_entry_plans(status="archived", db_path=db) + assert len(archived) == 1 + + +def test_archive_without_pnl_amount(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "plans.db" + row = create_entry_plan(_base_payload(symbol="DOGE"), db_path=db) + updated = update_entry_plan(int(row["id"]), {"result": "loss"}, db_path=db) + assert updated["status"] == "archived" + assert updated["pnl_amount"] is None + + +def test_cannot_delete_archived(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "plans.db" + row = create_entry_plan(_base_payload(), db_path=db) + update_entry_plan(int(row["id"]), {"result": "win"}, db_path=db) + try: + delete_entry_plan(int(row["id"]), db_path=db) + assert False, "expected ValueError" + except ValueError as e: + assert "仅进行中" in str(e) + + +def test_compute_stats_by_symbol(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "plans.db" + for sym, res in (("BTC", "win"), ("BTC", "loss"), ("ETH", "win")): + row = create_entry_plan(_base_payload(symbol=sym), db_path=db) + update_entry_plan(int(row["id"]), {"result": res}, db_path=db) + stats = compute_entry_plan_stats(dimension="symbol", period="all", db_path=db) + by_sym = {it["key"]: it for it in stats["items"]} + assert by_sym["BTC/USDT"]["win_count"] == 1 + assert by_sym["BTC/USDT"]["loss_count"] == 1 + assert by_sym["BTC/USDT"]["win_rate"] == 50.0 + assert by_sym["ETH/USDT"]["win_count"] == 1 + + +def test_stats_period_range_filter(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "plans.db" + row1 = create_entry_plan(_base_payload(plan_date="2026-06-01"), db_path=db) + row2 = create_entry_plan(_base_payload(plan_date="2026-06-20", symbol="ETH"), db_path=db) + update_entry_plan(int(row1["id"]), {"result": "win"}, db_path=db) + update_entry_plan(int(row2["id"]), {"result": "loss"}, db_path=db) + stats = compute_entry_plan_stats( + dimension="symbol", + period="range", + date_from="2026-06-01", + date_to="2026-06-10", + db_path=db, + ) + assert len(stats["items"]) == 1 + assert stats["items"][0]["key"] == "BTC/USDT" + + +def test_resolve_stats_date_bounds(): + df, dt, label = resolve_stats_date_bounds(period="all") + assert df is None and dt is None + assert "全部" in label diff --git a/tests/test_hub_fund_history_lib.py b/tests/test_hub_fund_history_lib.py index e779162..de44015 100644 --- a/tests/test_hub_fund_history_lib.py +++ b/tests/test_hub_fund_history_lib.py @@ -1,113 +1,113 @@ -"""hub_fund_history_lib:总资金、回撤与日快照。""" -from __future__ import annotations - -from hub_fund_history_lib import ( - account_total_usdt, - build_fund_overview, - compute_drawdown, - get_fund_history, - record_fund_snapshot, -) - - -def test_account_total_requires_both_sides(): - assert account_total_usdt(10, 20) == 30.0 - assert account_total_usdt(10, None) is None - assert account_total_usdt(None, 5) is None - - -def test_compute_drawdown(): - dd = compute_drawdown([100, 120, 90, 110]) - assert dd["peak_usdt"] == 120.0 - assert dd["max_drawdown_u"] == 30.0 - assert dd["max_drawdown_pct"] == 25.0 - - -def test_build_fund_overview_skips_unmonitored(tmp_path, monkeypatch): - hist_path = tmp_path / "hub_fund_history.json" - monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_PATH", hist_path) - record_fund_snapshot( - "2026-06-01", - [ - { - "key": "binance", - "name": "Binance", - "funding_usdt": 10, - "trading_usdt": 20, - "monitored": True, - } - ], - keep_days=180, - ) - record_fund_snapshot( - "2026-06-02", - [ - { - "key": "binance", - "name": "Binance", - "funding_usdt": 12, - "trading_usdt": 18, - "monitored": True, - } - ], - keep_days=180, - ) - exchanges = [ - {"id": "0", "key": "binance", "name": "Binance", "enabled": True}, - {"id": "3", "key": "gate_bot", "name": "Gate Bot", "enabled": False}, - ] - board_rows = [ - { - "key": "binance", - "name": "Binance", - "account_ok": True, - "funding_usdt": 15, - "trading_usdt": 25, - } - ] - out = build_fund_overview( - exchanges, - board_rows=board_rows, - trading_day="2026-06-02", - keep_days=180, - ) - assert out["totals"]["total_usdt"] == 40.0 - assert out["totals"]["monitored_count"] == 1 - assert len(out["accounts"]) == 1 - assert all(a["monitored"] for a in out["accounts"]) - assert out["totals"]["drawdown"]["max_drawdown_u"] == 0.0 - - -def test_history_start_day_filters_older(tmp_path, monkeypatch): - hist_path = tmp_path / "hub_fund_history.json" - monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_PATH", hist_path) - monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_START_DAY", "2026-06-09") - record_fund_snapshot( - "2026-06-01", - [ - { - "key": "binance", - "name": "Binance", - "funding_usdt": 1, - "trading_usdt": 1, - "monitored": True, - } - ], - keep_days=180, - ) - record_fund_snapshot( - "2026-06-09", - [ - { - "key": "binance", - "name": "Binance", - "funding_usdt": 10, - "trading_usdt": 20, - "monitored": True, - } - ], - keep_days=180, - ) - hist = get_fund_history(anchor_day="2026-06-10", keep_days=180) - assert "2026-06-01" not in hist - assert "2026-06-09" in hist +"""hub_fund_history_lib:总资金、回撤与日快照。""" +from __future__ import annotations + +from lib.hub.hub_fund_history_lib import ( + account_total_usdt, + build_fund_overview, + compute_drawdown, + get_fund_history, + record_fund_snapshot, +) + + +def test_account_total_requires_both_sides(): + assert account_total_usdt(10, 20) == 30.0 + assert account_total_usdt(10, None) is None + assert account_total_usdt(None, 5) is None + + +def test_compute_drawdown(): + dd = compute_drawdown([100, 120, 90, 110]) + assert dd["peak_usdt"] == 120.0 + assert dd["max_drawdown_u"] == 30.0 + assert dd["max_drawdown_pct"] == 25.0 + + +def test_build_fund_overview_skips_unmonitored(tmp_path, monkeypatch): + hist_path = tmp_path / "hub_fund_history.json" + monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_PATH", hist_path) + record_fund_snapshot( + "2026-06-01", + [ + { + "key": "binance", + "name": "Binance", + "funding_usdt": 10, + "trading_usdt": 20, + "monitored": True, + } + ], + keep_days=180, + ) + record_fund_snapshot( + "2026-06-02", + [ + { + "key": "binance", + "name": "Binance", + "funding_usdt": 12, + "trading_usdt": 18, + "monitored": True, + } + ], + keep_days=180, + ) + exchanges = [ + {"id": "0", "key": "binance", "name": "Binance", "enabled": True}, + {"id": "3", "key": "gate_bot", "name": "Gate Bot", "enabled": False}, + ] + board_rows = [ + { + "key": "binance", + "name": "Binance", + "account_ok": True, + "funding_usdt": 15, + "trading_usdt": 25, + } + ] + out = build_fund_overview( + exchanges, + board_rows=board_rows, + trading_day="2026-06-02", + keep_days=180, + ) + assert out["totals"]["total_usdt"] == 40.0 + assert out["totals"]["monitored_count"] == 1 + assert len(out["accounts"]) == 1 + assert all(a["monitored"] for a in out["accounts"]) + assert out["totals"]["drawdown"]["max_drawdown_u"] == 0.0 + + +def test_history_start_day_filters_older(tmp_path, monkeypatch): + hist_path = tmp_path / "hub_fund_history.json" + monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_PATH", hist_path) + monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_START_DAY", "2026-06-09") + record_fund_snapshot( + "2026-06-01", + [ + { + "key": "binance", + "name": "Binance", + "funding_usdt": 1, + "trading_usdt": 1, + "monitored": True, + } + ], + keep_days=180, + ) + record_fund_snapshot( + "2026-06-09", + [ + { + "key": "binance", + "name": "Binance", + "funding_usdt": 10, + "trading_usdt": 20, + "monitored": True, + } + ], + keep_days=180, + ) + hist = get_fund_history(anchor_day="2026-06-10", keep_days=180) + assert "2026-06-01" not in hist + assert "2026-06-09" in hist diff --git a/tests/test_hub_host_status_lib.py b/tests/test_hub_host_status_lib.py index cb4c640..1b9cf07 100644 --- a/tests/test_hub_host_status_lib.py +++ b/tests/test_hub_host_status_lib.py @@ -1,58 +1,58 @@ -"""hub_host_status_lib 单元测试。""" -from __future__ import annotations - -import sys -import unittest -from unittest.mock import MagicMock, patch - -from hub_host_status_lib import _disk_path, _state, get_host_status - - -class HubHostStatusLibTest(unittest.TestCase): - def setUp(self): - _state["primed"] = False - _state["net_ts"] = 0.0 - _state["net_sent"] = 0 - _state["net_recv"] = 0 - - def test_disk_path_env_override(self): - with patch.dict("os.environ", {"HUB_HOST_DISK_PATH": "/data"}, clear=False): - self.assertEqual(_disk_path(), "/data") - - def test_get_host_status_without_psutil(self): - import builtins - - real_import = builtins.__import__ - - def fake_import(name, globals=None, locals=None, fromlist=(), level=0): - if name == "psutil": - raise ImportError("no psutil") - return real_import(name, globals, locals, fromlist, level) - - with patch("builtins.__import__", side_effect=fake_import): - out = get_host_status() - self.assertFalse(out.get("ok")) - self.assertIn("psutil", out.get("msg", "")) - - def test_get_host_status_payload(self): - fake_vm = MagicMock(total=8_000_000_000, used=3_200_000_000, percent=40.0) - fake_du = MagicMock(total=100_000_000_000, used=50_000_000_000) - fake_net = MagicMock(bytes_sent=1_000_000, bytes_recv=2_000_000) - fake_psutil = MagicMock() - fake_psutil.cpu_percent.return_value = 12.5 - fake_psutil.cpu_count.return_value = 4 - fake_psutil.virtual_memory.return_value = fake_vm - fake_psutil.disk_usage.return_value = fake_du - fake_psutil.net_io_counters.return_value = fake_net - fake_psutil.boot_time.return_value = 1_700_000_000.0 - with patch.dict(sys.modules, {"psutil": fake_psutil}): - out = get_host_status() - self.assertTrue(out.get("ok")) - self.assertEqual(out["cpu"]["percent"], 12.5) - self.assertEqual(out["memory"]["percent"], 40.0) - self.assertEqual(out["disk"]["percent"], 50.0) - self.assertIn("network", out) - - -if __name__ == "__main__": - unittest.main() +"""hub_host_status_lib 单元测试。""" +from __future__ import annotations + +import sys +import unittest +from unittest.mock import MagicMock, patch + +from lib.hub.hub_host_status_lib import _disk_path, _state, get_host_status + + +class HubHostStatusLibTest(unittest.TestCase): + def setUp(self): + _state["primed"] = False + _state["net_ts"] = 0.0 + _state["net_sent"] = 0 + _state["net_recv"] = 0 + + def test_disk_path_env_override(self): + with patch.dict("os.environ", {"HUB_HOST_DISK_PATH": "/data"}, clear=False): + self.assertEqual(_disk_path(), "/data") + + def test_get_host_status_without_psutil(self): + import builtins + + real_import = builtins.__import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "psutil": + raise ImportError("no psutil") + return real_import(name, globals, locals, fromlist, level) + + with patch("builtins.__import__", side_effect=fake_import): + out = get_host_status() + self.assertFalse(out.get("ok")) + self.assertIn("psutil", out.get("msg", "")) + + def test_get_host_status_payload(self): + fake_vm = MagicMock(total=8_000_000_000, used=3_200_000_000, percent=40.0) + fake_du = MagicMock(total=100_000_000_000, used=50_000_000_000) + fake_net = MagicMock(bytes_sent=1_000_000, bytes_recv=2_000_000) + fake_psutil = MagicMock() + fake_psutil.cpu_percent.return_value = 12.5 + fake_psutil.cpu_count.return_value = 4 + fake_psutil.virtual_memory.return_value = fake_vm + fake_psutil.disk_usage.return_value = fake_du + fake_psutil.net_io_counters.return_value = fake_net + fake_psutil.boot_time.return_value = 1_700_000_000.0 + with patch.dict(sys.modules, {"psutil": fake_psutil}): + out = get_host_status() + self.assertTrue(out.get("ok")) + self.assertEqual(out["cpu"]["percent"], 12.5) + self.assertEqual(out["memory"]["percent"], 40.0) + self.assertEqual(out["disk"]["percent"], 50.0) + self.assertIn("network", out) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hub_kline_store.py b/tests/test_hub_kline_store.py index 6e97707..33245d4 100644 --- a/tests/test_hub_kline_store.py +++ b/tests/test_hub_kline_store.py @@ -1,466 +1,466 @@ -"""中控 K 线库:分周期保留、聚合与分页读取。""" -from __future__ import annotations - -import tempfile -import time -import unittest -from pathlib import Path - -from hub_kline_store import ( - HUB_KLINE_REMOTE_FETCH_CAP, - _since_ms_for_span, - clear_series_bars, - init_db, - load_bars_before, - load_bars_latest, - purge_retention, - purge_timeframe_by_days, - resolve_chart_bars, - retention_days, - trim_contiguous_tail, - upsert_bars, -) -from hub_ohlcv_lib import ( - TIMEFRAME_MS, - bar_limit_for_timeframe, - chart_fetch_start_ms, - chart_initial_limit, - last_closed_bar_open_ms, - window_start_ms, -) - - -class TestHubKlineStore(unittest.TestCase): - def setUp(self): - self.tmp = tempfile.TemporaryDirectory() - self.db = Path(self.tmp.name) / "test_hub_kline.db" - - def tearDown(self): - self.tmp.cleanup() - - def test_bar_limits(self): - self.assertEqual(bar_limit_for_timeframe("5m"), 5000) - self.assertEqual(bar_limit_for_timeframe("1h"), 1000) - self.assertEqual(bar_limit_for_timeframe("1d"), 1000) - self.assertEqual(bar_limit_for_timeframe("1w"), 500) - self.assertEqual(chart_initial_limit("5m"), 2000) - self.assertEqual(chart_initial_limit("1h"), 1000) - self.assertEqual(chart_initial_limit("1d"), 500) - - def test_chart_fetch_window_exceeds_retention(self): - now = int(time.time() * 1000) - need = bar_limit_for_timeframe("1d") - fetch_start = chart_fetch_start_ms("1d", need, now) - db_start = window_start_ms("1d", need, retention_days(), now) - self.assertLess(fetch_start, db_start) - - def test_purge_retention_5m_one_year(self): - init_db(self.db) - old_ms = int(time.time() * 1000) - 400 * 86400000 - upsert_bars( - "okx", - "BTC/USDT", - "5m", - [ - { - "open_time_ms": old_ms, - "open": 1, - "high": 2, - "low": 0.5, - "close": 1.5, - "volume": 10, - } - ], - self.db, - ) - n = purge_timeframe_by_days("5m", 365, self.db) - self.assertGreaterEqual(n, 1) - rows = load_bars_latest("okx", "BTC/USDT", "5m", 10, self.db) - self.assertEqual(len(rows), 0) - - def test_purge_retention_keeps_1d(self): - init_db(self.db) - old_ms = int(time.time() * 1000) - 400 * 86400000 - upsert_bars( - "okx", - "BTC/USDT", - "1d", - [ - { - "open_time_ms": old_ms, - "open": 1, - "high": 2, - "low": 0.5, - "close": 1.5, - "volume": 10, - } - ], - self.db, - ) - purge_retention(self.db) - rows = load_bars_latest("okx", "BTC/USDT", "1d", 10, self.db) - self.assertEqual(len(rows), 1) - - def test_resolve_uses_cache_without_remote(self): - init_db(self.db) - now = int(time.time() * 1000) - tf = "5m" - period = TIMEFRAME_MS[tf] - last_closed = last_closed_bar_open_ms(tf, now) - bars = [] - for i in range(400): - oms = last_closed - (399 - i) * period - bars.append( - { - "open_time_ms": oms, - "open": 100 + i, - "high": 101 + i, - "low": 99 + i, - "close": 100.5 + i, - "volume": 1000 + i, - } - ) - upsert_bars("okx", "ETH/USDT", tf, bars, self.db) - - def remote_fetch(**kwargs): - self.fail("不应请求交易所") - - out = resolve_chart_bars( - "okx", - "ETH/USDT", - tf, - remote_fetch, - db_path=self.db, - limit=300, - ) - self.assertTrue(out.get("ok")) - self.assertEqual(len(out.get("candles") or []), 300) - - def test_resolve_15m_reads_native_bars(self): - init_db(self.db) - now = int(time.time() * 1000) - period = TIMEFRAME_MS["15m"] - last_closed = last_closed_bar_open_ms("15m", now) - bars = [] - for i in range(12): - oms = last_closed - (11 - i) * period - bars.append( - { - "open_time_ms": oms, - "open": 1.0 + i, - "high": 2.0 + i, - "low": 0.5 + i, - "close": 1.5 + i, - "volume": 10.0, - } - ) - upsert_bars("okx", "ETH/USDT", "15m", bars, self.db) - - def remote_fetch(**kwargs): - self.fail("不应请求交易所") - - out = resolve_chart_bars( - "okx", - "ETH/USDT", - "15m", - remote_fetch, - db_path=self.db, - limit=10, - ) - self.assertTrue(out.get("ok")) - self.assertEqual(out.get("source"), "db") - self.assertEqual(out.get("storage_timeframe"), "15m") - self.assertGreaterEqual(len(out.get("candles") or []), 10) - - def test_load_bars_before(self): - init_db(self.db) - period = TIMEFRAME_MS["1h"] - base = 1_700_000_000_000 - bars = [] - for i in range(5): - bars.append( - { - "open_time_ms": base + i * period, - "open": 1, - "high": 2, - "low": 0.5, - "close": 1.5, - "volume": 1, - } - ) - upsert_bars("okx", "BTC/USDT", "1h", bars, self.db) - before = base + 3 * period - got = load_bars_before("okx", "BTC/USDT", "1h", before, 2, self.db) - self.assertEqual(len(got), 2) - self.assertEqual(got[-1]["open_time_ms"], base + 2 * period) - - def test_trim_contiguous_tail_drops_orphan_prefix(self): - period = TIMEFRAME_MS["15m"] - base_old = 1_700_000_000_000 - base_new = base_old + period * 500 - bars = [] - for i in range(3): - bars.append( - { - "open_time_ms": base_old + i * period, - "open": 1, - "high": 2, - "low": 0.5, - "close": 1.5, - "volume": 1, - } - ) - for i in range(5): - bars.append( - { - "open_time_ms": base_new + i * period, - "open": 2, - "high": 3, - "low": 1.5, - "close": 2.5, - "volume": 2, - } - ) - trimmed, split = trim_contiguous_tail(bars, period) - self.assertEqual(split, 3) - self.assertEqual(len(trimmed), 5) - self.assertEqual(trimmed[0]["open_time_ms"], base_new) - - def test_resolve_drops_discontinuous_orphans(self): - init_db(self.db) - period = TIMEFRAME_MS["15m"] - now = int(time.time() * 1000) - old_ms = now - period * 800 - upsert_bars( - "okx", - "ONDO/USDT", - "15m", - [ - { - "open_time_ms": old_ms, - "open": 0.33, - "high": 0.34, - "low": 0.32, - "close": 0.335, - "volume": 100, - } - ], - self.db, - ) - recent = [] - start = now - period * 20 - for i in range(20): - recent.append( - { - "open_time_ms": start + i * period, - "open": 0.35, - "high": 0.36, - "low": 0.34, - "close": 0.355, - "volume": 50, - } - ) - - def remote_fetch(**kwargs): - return {"ok": True, "bars": recent, "price_tick": 0.0001} - - out = resolve_chart_bars( - "okx", - "ONDO/USDT", - "15m", - remote_fetch, - db_path=self.db, - limit=50, - ) - self.assertTrue(out.get("ok")) - candles = out.get("candles") or [] - self.assertGreaterEqual(len(candles), 19) - if len(candles) >= 2: - for i in range(1, len(candles)): - gap = candles[i]["time"] - candles[i - 1]["time"] - self.assertLessEqual(gap, int(period / 1000 * 3.0)) - - def test_resolve_refetches_when_db_has_discontinuous_full_count(self): - init_db(self.db) - period = TIMEFRAME_MS["15m"] - now = int(time.time() * 1000) - old_start = now - period * 3000 - recent_start = now - period * 25 - old_bars = [ - { - "open_time_ms": old_start + i * period, - "open": 62000, - "high": 62100, - "low": 61900, - "close": 62050, - "volume": 10, - } - for i in range(500) - ] - recent = [ - { - "open_time_ms": recent_start + i * period, - "open": 104000, - "high": 104100, - "low": 103900, - "close": 104050, - "volume": 20, - } - for i in range(30) - ] - upsert_bars("binance", "BTC/USDT", "15m", old_bars, self.db) - upsert_bars("binance", "BTC/USDT", "15m", recent, self.db) - fetch_calls = [] - - def remote_fetch(**kwargs): - fetch_calls.append(dict(kwargs)) - full = [] - start = now - period * 120 - for i in range(120): - full.append( - { - "open_time_ms": start + i * period, - "open": 104000 + i, - "high": 104100 + i, - "low": 103900 + i, - "close": 104050 + i, - "volume": 30, - } - ) - return {"ok": True, "bars": full, "price_tick": 0.01} - - out = resolve_chart_bars( - "binance", - "BTC/USDT", - "15m", - remote_fetch, - db_path=self.db, - limit=2000, - ) - self.assertTrue(out.get("ok")) - self.assertGreater(len(fetch_calls), 0) - self.assertGreaterEqual(len(out.get("candles") or []), 100) - self.assertGreater(int(out.get("fetched") or 0), 0) - - def test_clear_series_and_force_refetch(self): - init_db(self.db) - period = TIMEFRAME_MS["5m"] - now = int(time.time() * 1000) - stale = [ - { - "open_time_ms": now - period * (i + 100), - "open": 1, - "high": 2, - "low": 0.5, - "close": 1.5, - "volume": 1, - } - for i in range(40) - ] - upsert_bars("binance", "BTC/USDT", "5m", stale, self.db) - self.assertEqual(len(load_bars_latest("binance", "BTC/USDT", "5m", 100, self.db)), 40) - removed = clear_series_bars("binance", "BTC/USDT", "5m", self.db) - self.assertEqual(removed, 40) - self.assertEqual(len(load_bars_latest("binance", "BTC/USDT", "5m", 100, self.db)), 0) - - fresh = [ - { - "open_time_ms": now - period * (20 - i), - "open": 10, - "high": 11, - "low": 9, - "close": 10.5, - "volume": 2, - } - for i in range(20) - ] - - def remote_fetch(**kwargs): - return {"ok": True, "bars": fresh, "price_tick": 0.01} - - out = resolve_chart_bars( - "binance", - "BTC/USDT", - "5m", - remote_fetch, - db_path=self.db, - force_refresh=True, - clear_db=True, - limit=50, - ) - self.assertTrue(out.get("ok")) - self.assertGreaterEqual(int(out.get("cleared") or 0), 0) - self.assertGreater(int(out.get("fetched") or 0), 0) - self.assertGreaterEqual(len(out.get("candles") or []), 19) - - def test_since_span_matches_fetch_limit_not_need(self): - period = TIMEFRAME_MS["15m"] - now_ms = 1_800_000_000_000 - fetch_limit = HUB_KLINE_REMOTE_FETCH_CAP - since = _since_ms_for_span( - now_ms=now_ms, - period_ms=period, - span_bars=fetch_limit, - cutoff_ms=0, - ) - self.assertEqual(since, now_ms - period * fetch_limit) - wrong_since = now_ms - period * chart_initial_limit("15m") - self.assertGreater(since, wrong_since) - - def test_thin_series_tail_refresh_fetches_full_window(self): - init_db(self.db) - period = TIMEFRAME_MS["15m"] - now = int(time.time() * 1000) - last_closed = last_closed_bar_open_ms("15m", now) - bars = [ - { - "open_time_ms": last_closed - period * (150 - i), - "open": 100000, - "high": 100100, - "low": 99900, - "close": 100050, - "volume": 1, - } - for i in range(150) - ] - fetch_calls: list[dict] = [] - - def remote_fetch(**kwargs): - fetch_calls.append(dict(kwargs)) - return {"ok": True, "bars": bars, "price_tick": 0.01} - - out = resolve_chart_bars( - "binance", - "BTC/USDT", - "15m", - remote_fetch, - db_path=self.db, - tail_refresh=True, - ) - self.assertTrue(out.get("ok")) - self.assertGreaterEqual(len(out.get("candles") or []), 100) - self.assertGreater(int(out.get("fetched") or 0), 0) - self.assertTrue(any(int(c.get("limit") or 0) > 30 for c in fetch_calls)) - - def test_resolve_before_ms_exhausted(self): - init_db(self.db) - - def remote_fetch(**kwargs): - return {"ok": False, "msg": "no remote"} - - out = resolve_chart_bars( - "okx", - "BTC/USDT", - "5m", - remote_fetch, - db_path=self.db, - limit=100, - before_ms=int(time.time() * 1000), - ) - self.assertTrue(out.get("ok")) - self.assertEqual(out.get("candles"), []) - self.assertTrue(out.get("exhausted")) - - -if __name__ == "__main__": - unittest.main() +"""中控 K 线库:分周期保留、聚合与分页读取。""" +from __future__ import annotations + +import tempfile +import time +import unittest +from pathlib import Path + +from lib.hub.hub_kline_store import ( + HUB_KLINE_REMOTE_FETCH_CAP, + _since_ms_for_span, + clear_series_bars, + init_db, + load_bars_before, + load_bars_latest, + purge_retention, + purge_timeframe_by_days, + resolve_chart_bars, + retention_days, + trim_contiguous_tail, + upsert_bars, +) +from lib.hub.hub_ohlcv_lib import ( + TIMEFRAME_MS, + bar_limit_for_timeframe, + chart_fetch_start_ms, + chart_initial_limit, + last_closed_bar_open_ms, + window_start_ms, +) + + +class TestHubKlineStore(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.TemporaryDirectory() + self.db = Path(self.tmp.name) / "test_hub_kline.db" + + def tearDown(self): + self.tmp.cleanup() + + def test_bar_limits(self): + self.assertEqual(bar_limit_for_timeframe("5m"), 5000) + self.assertEqual(bar_limit_for_timeframe("1h"), 1000) + self.assertEqual(bar_limit_for_timeframe("1d"), 1000) + self.assertEqual(bar_limit_for_timeframe("1w"), 500) + self.assertEqual(chart_initial_limit("5m"), 2000) + self.assertEqual(chart_initial_limit("1h"), 1000) + self.assertEqual(chart_initial_limit("1d"), 500) + + def test_chart_fetch_window_exceeds_retention(self): + now = int(time.time() * 1000) + need = bar_limit_for_timeframe("1d") + fetch_start = chart_fetch_start_ms("1d", need, now) + db_start = window_start_ms("1d", need, retention_days(), now) + self.assertLess(fetch_start, db_start) + + def test_purge_retention_5m_one_year(self): + init_db(self.db) + old_ms = int(time.time() * 1000) - 400 * 86400000 + upsert_bars( + "okx", + "BTC/USDT", + "5m", + [ + { + "open_time_ms": old_ms, + "open": 1, + "high": 2, + "low": 0.5, + "close": 1.5, + "volume": 10, + } + ], + self.db, + ) + n = purge_timeframe_by_days("5m", 365, self.db) + self.assertGreaterEqual(n, 1) + rows = load_bars_latest("okx", "BTC/USDT", "5m", 10, self.db) + self.assertEqual(len(rows), 0) + + def test_purge_retention_keeps_1d(self): + init_db(self.db) + old_ms = int(time.time() * 1000) - 400 * 86400000 + upsert_bars( + "okx", + "BTC/USDT", + "1d", + [ + { + "open_time_ms": old_ms, + "open": 1, + "high": 2, + "low": 0.5, + "close": 1.5, + "volume": 10, + } + ], + self.db, + ) + purge_retention(self.db) + rows = load_bars_latest("okx", "BTC/USDT", "1d", 10, self.db) + self.assertEqual(len(rows), 1) + + def test_resolve_uses_cache_without_remote(self): + init_db(self.db) + now = int(time.time() * 1000) + tf = "5m" + period = TIMEFRAME_MS[tf] + last_closed = last_closed_bar_open_ms(tf, now) + bars = [] + for i in range(400): + oms = last_closed - (399 - i) * period + bars.append( + { + "open_time_ms": oms, + "open": 100 + i, + "high": 101 + i, + "low": 99 + i, + "close": 100.5 + i, + "volume": 1000 + i, + } + ) + upsert_bars("okx", "ETH/USDT", tf, bars, self.db) + + def remote_fetch(**kwargs): + self.fail("不应请求交易所") + + out = resolve_chart_bars( + "okx", + "ETH/USDT", + tf, + remote_fetch, + db_path=self.db, + limit=300, + ) + self.assertTrue(out.get("ok")) + self.assertEqual(len(out.get("candles") or []), 300) + + def test_resolve_15m_reads_native_bars(self): + init_db(self.db) + now = int(time.time() * 1000) + period = TIMEFRAME_MS["15m"] + last_closed = last_closed_bar_open_ms("15m", now) + bars = [] + for i in range(12): + oms = last_closed - (11 - i) * period + bars.append( + { + "open_time_ms": oms, + "open": 1.0 + i, + "high": 2.0 + i, + "low": 0.5 + i, + "close": 1.5 + i, + "volume": 10.0, + } + ) + upsert_bars("okx", "ETH/USDT", "15m", bars, self.db) + + def remote_fetch(**kwargs): + self.fail("不应请求交易所") + + out = resolve_chart_bars( + "okx", + "ETH/USDT", + "15m", + remote_fetch, + db_path=self.db, + limit=10, + ) + self.assertTrue(out.get("ok")) + self.assertEqual(out.get("source"), "db") + self.assertEqual(out.get("storage_timeframe"), "15m") + self.assertGreaterEqual(len(out.get("candles") or []), 10) + + def test_load_bars_before(self): + init_db(self.db) + period = TIMEFRAME_MS["1h"] + base = 1_700_000_000_000 + bars = [] + for i in range(5): + bars.append( + { + "open_time_ms": base + i * period, + "open": 1, + "high": 2, + "low": 0.5, + "close": 1.5, + "volume": 1, + } + ) + upsert_bars("okx", "BTC/USDT", "1h", bars, self.db) + before = base + 3 * period + got = load_bars_before("okx", "BTC/USDT", "1h", before, 2, self.db) + self.assertEqual(len(got), 2) + self.assertEqual(got[-1]["open_time_ms"], base + 2 * period) + + def test_trim_contiguous_tail_drops_orphan_prefix(self): + period = TIMEFRAME_MS["15m"] + base_old = 1_700_000_000_000 + base_new = base_old + period * 500 + bars = [] + for i in range(3): + bars.append( + { + "open_time_ms": base_old + i * period, + "open": 1, + "high": 2, + "low": 0.5, + "close": 1.5, + "volume": 1, + } + ) + for i in range(5): + bars.append( + { + "open_time_ms": base_new + i * period, + "open": 2, + "high": 3, + "low": 1.5, + "close": 2.5, + "volume": 2, + } + ) + trimmed, split = trim_contiguous_tail(bars, period) + self.assertEqual(split, 3) + self.assertEqual(len(trimmed), 5) + self.assertEqual(trimmed[0]["open_time_ms"], base_new) + + def test_resolve_drops_discontinuous_orphans(self): + init_db(self.db) + period = TIMEFRAME_MS["15m"] + now = int(time.time() * 1000) + old_ms = now - period * 800 + upsert_bars( + "okx", + "ONDO/USDT", + "15m", + [ + { + "open_time_ms": old_ms, + "open": 0.33, + "high": 0.34, + "low": 0.32, + "close": 0.335, + "volume": 100, + } + ], + self.db, + ) + recent = [] + start = now - period * 20 + for i in range(20): + recent.append( + { + "open_time_ms": start + i * period, + "open": 0.35, + "high": 0.36, + "low": 0.34, + "close": 0.355, + "volume": 50, + } + ) + + def remote_fetch(**kwargs): + return {"ok": True, "bars": recent, "price_tick": 0.0001} + + out = resolve_chart_bars( + "okx", + "ONDO/USDT", + "15m", + remote_fetch, + db_path=self.db, + limit=50, + ) + self.assertTrue(out.get("ok")) + candles = out.get("candles") or [] + self.assertGreaterEqual(len(candles), 19) + if len(candles) >= 2: + for i in range(1, len(candles)): + gap = candles[i]["time"] - candles[i - 1]["time"] + self.assertLessEqual(gap, int(period / 1000 * 3.0)) + + def test_resolve_refetches_when_db_has_discontinuous_full_count(self): + init_db(self.db) + period = TIMEFRAME_MS["15m"] + now = int(time.time() * 1000) + old_start = now - period * 3000 + recent_start = now - period * 25 + old_bars = [ + { + "open_time_ms": old_start + i * period, + "open": 62000, + "high": 62100, + "low": 61900, + "close": 62050, + "volume": 10, + } + for i in range(500) + ] + recent = [ + { + "open_time_ms": recent_start + i * period, + "open": 104000, + "high": 104100, + "low": 103900, + "close": 104050, + "volume": 20, + } + for i in range(30) + ] + upsert_bars("binance", "BTC/USDT", "15m", old_bars, self.db) + upsert_bars("binance", "BTC/USDT", "15m", recent, self.db) + fetch_calls = [] + + def remote_fetch(**kwargs): + fetch_calls.append(dict(kwargs)) + full = [] + start = now - period * 120 + for i in range(120): + full.append( + { + "open_time_ms": start + i * period, + "open": 104000 + i, + "high": 104100 + i, + "low": 103900 + i, + "close": 104050 + i, + "volume": 30, + } + ) + return {"ok": True, "bars": full, "price_tick": 0.01} + + out = resolve_chart_bars( + "binance", + "BTC/USDT", + "15m", + remote_fetch, + db_path=self.db, + limit=2000, + ) + self.assertTrue(out.get("ok")) + self.assertGreater(len(fetch_calls), 0) + self.assertGreaterEqual(len(out.get("candles") or []), 100) + self.assertGreater(int(out.get("fetched") or 0), 0) + + def test_clear_series_and_force_refetch(self): + init_db(self.db) + period = TIMEFRAME_MS["5m"] + now = int(time.time() * 1000) + stale = [ + { + "open_time_ms": now - period * (i + 100), + "open": 1, + "high": 2, + "low": 0.5, + "close": 1.5, + "volume": 1, + } + for i in range(40) + ] + upsert_bars("binance", "BTC/USDT", "5m", stale, self.db) + self.assertEqual(len(load_bars_latest("binance", "BTC/USDT", "5m", 100, self.db)), 40) + removed = clear_series_bars("binance", "BTC/USDT", "5m", self.db) + self.assertEqual(removed, 40) + self.assertEqual(len(load_bars_latest("binance", "BTC/USDT", "5m", 100, self.db)), 0) + + fresh = [ + { + "open_time_ms": now - period * (20 - i), + "open": 10, + "high": 11, + "low": 9, + "close": 10.5, + "volume": 2, + } + for i in range(20) + ] + + def remote_fetch(**kwargs): + return {"ok": True, "bars": fresh, "price_tick": 0.01} + + out = resolve_chart_bars( + "binance", + "BTC/USDT", + "5m", + remote_fetch, + db_path=self.db, + force_refresh=True, + clear_db=True, + limit=50, + ) + self.assertTrue(out.get("ok")) + self.assertGreaterEqual(int(out.get("cleared") or 0), 0) + self.assertGreater(int(out.get("fetched") or 0), 0) + self.assertGreaterEqual(len(out.get("candles") or []), 19) + + def test_since_span_matches_fetch_limit_not_need(self): + period = TIMEFRAME_MS["15m"] + now_ms = 1_800_000_000_000 + fetch_limit = HUB_KLINE_REMOTE_FETCH_CAP + since = _since_ms_for_span( + now_ms=now_ms, + period_ms=period, + span_bars=fetch_limit, + cutoff_ms=0, + ) + self.assertEqual(since, now_ms - period * fetch_limit) + wrong_since = now_ms - period * chart_initial_limit("15m") + self.assertGreater(since, wrong_since) + + def test_thin_series_tail_refresh_fetches_full_window(self): + init_db(self.db) + period = TIMEFRAME_MS["15m"] + now = int(time.time() * 1000) + last_closed = last_closed_bar_open_ms("15m", now) + bars = [ + { + "open_time_ms": last_closed - period * (150 - i), + "open": 100000, + "high": 100100, + "low": 99900, + "close": 100050, + "volume": 1, + } + for i in range(150) + ] + fetch_calls: list[dict] = [] + + def remote_fetch(**kwargs): + fetch_calls.append(dict(kwargs)) + return {"ok": True, "bars": bars, "price_tick": 0.01} + + out = resolve_chart_bars( + "binance", + "BTC/USDT", + "15m", + remote_fetch, + db_path=self.db, + tail_refresh=True, + ) + self.assertTrue(out.get("ok")) + self.assertGreaterEqual(len(out.get("candles") or []), 100) + self.assertGreater(int(out.get("fetched") or 0), 0) + self.assertTrue(any(int(c.get("limit") or 0) > 30 for c in fetch_calls)) + + def test_resolve_before_ms_exhausted(self): + init_db(self.db) + + def remote_fetch(**kwargs): + return {"ok": False, "msg": "no remote"} + + out = resolve_chart_bars( + "okx", + "BTC/USDT", + "5m", + remote_fetch, + db_path=self.db, + limit=100, + before_ms=int(time.time() * 1000), + ) + self.assertTrue(out.get("ok")) + self.assertEqual(out.get("candles"), []) + self.assertTrue(out.get("exhausted")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hub_macro_calendar_lib.py b/tests/test_hub_macro_calendar_lib.py index 33815be..c4c2f32 100644 --- a/tests/test_hub_macro_calendar_lib.py +++ b/tests/test_hub_macro_calendar_lib.py @@ -1,73 +1,73 @@ -import os -import tempfile -import unittest -from pathlib import Path -from unittest import mock - -from hub_macro_calendar_lib import ( - build_banner_message, - create_event, - delete_event, - enrich_alert, - init_db, - list_active_alerts, - list_events, - update_event, -) - - -class HubMacroCalendarLibTests(unittest.TestCase): - def setUp(self): - self.tmp = tempfile.TemporaryDirectory() - self.db_path = Path(self.tmp.name) / "macro.db" - init_db(self.db_path) - - def tearDown(self): - self.tmp.cleanup() - - def test_create_and_list(self): - row = create_event("cpi", "2026-06-18 20:30", note="核心CPI", db_path=self.db_path) - self.assertEqual(row["event_type"], "cpi") - self.assertEqual(row["event_at"], "2026-06-18 20:30") - rows = list_events(now_ms=row["event_at_ms"] - 86400000, db_path=self.db_path) - self.assertEqual(len(rows), 1) - - def test_duplicate_rejected(self): - create_event("fomc", "2026-07-01 02:00", db_path=self.db_path) - with self.assertRaises(ValueError): - create_event("fomc", "2026-07-01 02:00", db_path=self.db_path) - - def test_active_window_and_messages(self): - row = create_event("employment", "2026-06-18 20:30", db_path=self.db_path) - t0 = int(row["event_at_ms"]) - inside = enrich_alert(row, now_ms=t0 - 30 * 60 * 1000) - self.assertIsNotNone(inside) - self.assertEqual(inside["phase"], "imminent") - outside = enrich_alert(row, now_ms=t0 - 2 * 3600 * 1000) - self.assertIsNone(outside) - alerts = list_active_alerts(now_ms=t0 + 15 * 60 * 1000, db_path=self.db_path) - self.assertEqual(len(alerts), 1) - msg_pos = build_banner_message(alerts[0], has_positions=True) - msg_flat = build_banner_message(alerts[0], has_positions=False) - self.assertIn("注意仓位风险", msg_pos) - self.assertIn("建议等待", msg_flat) - - def test_update_and_delete(self): - row = create_event("cpi", "2026-06-18 20:30", db_path=self.db_path) - updated = update_event( - row["id"], - event_at="2026-06-18 21:00", - note="修正时间", - db_path=self.db_path, - ) - self.assertEqual(updated["event_at"], "2026-06-18 21:00") - self.assertTrue(delete_event(row["id"], db_path=self.db_path)) - self.assertEqual(len(list_events(now_ms=updated["event_at_ms"], db_path=self.db_path)), 0) - - def test_invalid_type(self): - with self.assertRaises(ValueError): - create_event("nfp", "2026-06-18 20:30", db_path=self.db_path) - - -if __name__ == "__main__": - unittest.main() +import os +import tempfile +import unittest +from pathlib import Path +from unittest import mock + +from lib.hub.hub_macro_calendar_lib import ( + build_banner_message, + create_event, + delete_event, + enrich_alert, + init_db, + list_active_alerts, + list_events, + update_event, +) + + +class HubMacroCalendarLibTests(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.TemporaryDirectory() + self.db_path = Path(self.tmp.name) / "macro.db" + init_db(self.db_path) + + def tearDown(self): + self.tmp.cleanup() + + def test_create_and_list(self): + row = create_event("cpi", "2026-06-18 20:30", note="核心CPI", db_path=self.db_path) + self.assertEqual(row["event_type"], "cpi") + self.assertEqual(row["event_at"], "2026-06-18 20:30") + rows = list_events(now_ms=row["event_at_ms"] - 86400000, db_path=self.db_path) + self.assertEqual(len(rows), 1) + + def test_duplicate_rejected(self): + create_event("fomc", "2026-07-01 02:00", db_path=self.db_path) + with self.assertRaises(ValueError): + create_event("fomc", "2026-07-01 02:00", db_path=self.db_path) + + def test_active_window_and_messages(self): + row = create_event("employment", "2026-06-18 20:30", db_path=self.db_path) + t0 = int(row["event_at_ms"]) + inside = enrich_alert(row, now_ms=t0 - 30 * 60 * 1000) + self.assertIsNotNone(inside) + self.assertEqual(inside["phase"], "imminent") + outside = enrich_alert(row, now_ms=t0 - 2 * 3600 * 1000) + self.assertIsNone(outside) + alerts = list_active_alerts(now_ms=t0 + 15 * 60 * 1000, db_path=self.db_path) + self.assertEqual(len(alerts), 1) + msg_pos = build_banner_message(alerts[0], has_positions=True) + msg_flat = build_banner_message(alerts[0], has_positions=False) + self.assertIn("注意仓位风险", msg_pos) + self.assertIn("建议等待", msg_flat) + + def test_update_and_delete(self): + row = create_event("cpi", "2026-06-18 20:30", db_path=self.db_path) + updated = update_event( + row["id"], + event_at="2026-06-18 21:00", + note="修正时间", + db_path=self.db_path, + ) + self.assertEqual(updated["event_at"], "2026-06-18 21:00") + self.assertTrue(delete_event(row["id"], db_path=self.db_path)) + self.assertEqual(len(list_events(now_ms=updated["event_at_ms"], db_path=self.db_path)), 0) + + def test_invalid_type(self): + with self.assertRaises(ValueError): + create_event("nfp", "2026-06-18 20:30", db_path=self.db_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hub_monitor_payload.py b/tests/test_hub_monitor_payload.py index 4912c10..b042a08 100644 --- a/tests/test_hub_monitor_payload.py +++ b/tests/test_hub_monitor_payload.py @@ -1,39 +1,39 @@ -"""hub /api/hub/monitor:enrich 局部返回时须保留 keys。""" -from __future__ import annotations - -import sys -import unittest -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(ROOT)) - -from hub_bridge import build_hub_monitor_payload # noqa: E402 - - -class TestHubMonitorPayload(unittest.TestCase): - def test_partial_enrich_keeps_keys(self): - keys = [{"id": 7, "symbol": "BTC/USDT"}] - orders = [{"id": 1}] - trends = [{"id": 9, "symbol": "ETH/USDT"}] - rolls = [] - - def enrich_only_trends(**_kw): - return {"trends": [{"id": 9, "add_count": 2}]} - - out = build_hub_monitor_payload( - keys=keys, - orders=orders, - trends=trends, - rolls=rolls, - enrich=enrich_only_trends, - ) - self.assertTrue(out["ok"]) - self.assertEqual(out["keys"], keys) - self.assertEqual(out["orders"], orders) - self.assertEqual(out["rolls"], rolls) - self.assertEqual(out["trends"][0]["add_count"], 2) - - -if __name__ == "__main__": - unittest.main() +"""hub /api/hub/monitor:enrich 局部返回时须保留 keys。""" +from __future__ import annotations + +import sys +import unittest +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + +from lib.hub.hub_bridge import build_hub_monitor_payload # noqa: E402 + + +class TestHubMonitorPayload(unittest.TestCase): + def test_partial_enrich_keeps_keys(self): + keys = [{"id": 7, "symbol": "BTC/USDT"}] + orders = [{"id": 1}] + trends = [{"id": 9, "symbol": "ETH/USDT"}] + rolls = [] + + def enrich_only_trends(**_kw): + return {"trends": [{"id": 9, "add_count": 2}]} + + out = build_hub_monitor_payload( + keys=keys, + orders=orders, + trends=trends, + rolls=rolls, + enrich=enrich_only_trends, + ) + self.assertTrue(out["ok"]) + self.assertEqual(out["keys"], keys) + self.assertEqual(out["orders"], orders) + self.assertEqual(out["rolls"], rolls) + self.assertEqual(out["trends"][0]["add_count"], 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hub_ohlcv_lib.py b/tests/test_hub_ohlcv_lib.py index 99786ff..ba9e637 100644 --- a/tests/test_hub_ohlcv_lib.py +++ b/tests/test_hub_ohlcv_lib.py @@ -1,223 +1,222 @@ -"""hub_ohlcv_lib:分页拉取(Gate 等单次不足 chunk 时仍继续)。""" -from __future__ import annotations - -import unittest - -from hub_ohlcv_lib import ( - aggregate_ohlcv_bars, - bars_spacing_matches_timeframe, - fetch_ohlcv_for_hub, - normalize_price_tick, -) - - -class _FakeExchange: - def __init__(self, pages, *, timeframes=None): - self.pages = list(pages) - self.calls = [] - self.markets = {} - self.timeframes = timeframes if timeframes is not None else {} - - def fetch_ohlcv(self, symbol, timeframe=None, since=None, limit=None): - self.calls.append( - {"symbol": symbol, "since": since, "limit": limit, "timeframe": timeframe} - ) - if not self.pages: - return [] - page = self.pages.pop(0) - if since is None: - return page - return [b for b in page if b[0] >= since] - - -class TestHubOhlcvLib(unittest.TestCase): - def test_normalize_price_tick_snaps_powers_of_ten(self): - self.assertAlmostEqual(normalize_price_tick(0.00001), 0.00001) - self.assertAlmostEqual(normalize_price_tick(0.001), 0.001) - self.assertIsNone(normalize_price_tick(0)) - - def test_price_tick_from_decimal_precision(self): - class _Ex: - markets = {"BTC/USDT:USDT": {"precision": {"price": 2}, "info": {}, "limits": {}}} - - def load_markets(self): - return self.markets - - def market(self, sym): - return self.markets[sym] - - def price_to_precision(self, sym, price): - return "12345.67" - - tick = __import__("hub_ohlcv_lib", fromlist=["price_tick_from_market"]).price_tick_from_market( - _Ex(), "BTC/USDT:USDT" - ) - self.assertAlmostEqual(tick, 0.01) - - def test_price_tick_from_binance_price_filter(self): - class _Ex: - markets = { - "BTC/USDT:USDT": { - "precision": {"price": 2}, - "info": { - "filters": [ - {"filterType": "PRICE_FILTER", "tickSize": "0.10"}, - {"filterType": "LOT_SIZE", "stepSize": "0.001"}, - ] - }, - "limits": {}, - } - } - - def load_markets(self): - return self.markets - - def market(self, sym): - return self.markets[sym] - - def price_to_precision(self, sym, price): - return "12345.6" - - from hub_ohlcv_lib import price_tick_from_market - - tick = price_tick_from_market(_Ex(), "BTC/USDT:USDT") - self.assertAlmostEqual(tick, 0.10) - - def test_price_tick_from_info_tick_size(self): - class _Ex: - markets = { - "INJ/USDT:USDT": { - "precision": {"price": 4}, - "info": {"tickSize": "0.001"}, - "limits": {}, - } - } - - def load_markets(self): - return self.markets - - def market(self, sym): - return self.markets[sym] - - def price_to_precision(self, sym, price): - return "7.123" - - from hub_ohlcv_lib import price_tick_from_market - - tick = price_tick_from_market(_Ex(), "INJ/USDT:USDT") - self.assertAlmostEqual(tick, 0.001) - - def test_full_fetch_without_since_paginates_okx_style(self): - """OKX 等无 since 单次约 300 根,须分页至 limit。""" - from hub_ohlcv_lib import TIMEFRAME_MS - - step = TIMEFRAME_MS["1h"] - want = 1000 - base = max(0, int(__import__("time").time() * 1000) - want * step) - pages = [ - [[base + i * step, 1.0, 1.1, 0.9, 1.05, 100.0] for i in range(300)], - [[base + (300 + i) * step, 2.0, 2.1, 1.9, 2.05, 200.0] for i in range(300)], - [[base + (600 + i) * step, 3.0, 3.1, 2.9, 3.05, 300.0] for i in range(300)], - [[base + (900 + i) * step, 4.0, 4.1, 3.9, 4.05, 400.0] for i in range(100)], - ] - ex = _FakeExchange(pages) - - out = fetch_ohlcv_for_hub( - symbol="ONDO/USDT", - timeframe="1h", - since_ms=None, - limit=want, - normalize_symbol_input=lambda s: str(s).strip().upper(), - normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s, - ensure_markets_loaded=lambda: None, - exchange=ex, - ) - self.assertTrue(out.get("ok")) - self.assertEqual(len(out.get("bars") or []), 1000) - self.assertGreaterEqual(len(ex.calls), 4) - self.assertAlmostEqual(out["bars"][-1]["close"], 4.05) - - def test_pagination_continues_when_page_smaller_than_chunk(self): - """Gate 等常返回 299 根/次,不应误判为已到末尾。""" - base = 1_700_000_000_000 - step = 4 * 60 * 60 * 1000 - page1 = [ - [base + i * step, 1.0, 1.1, 0.9, 1.05, 100.0] for i in range(299) - ] - page2 = [ - [base + (299 + i) * step, 2.0, 2.1, 1.9, 2.05, 200.0] for i in range(299) - ] - page3 = [ - [base + (598 + i) * step, 3.0, 3.1, 2.9, 3.05, 300.0] for i in range(50) - ] - ex = _FakeExchange([page1, page2, page3]) - - out = fetch_ohlcv_for_hub( - symbol="INJ/USDT", - timeframe="4h", - since_ms=base, - limit=600, - normalize_symbol_input=lambda s: str(s).strip().upper(), - normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s, - ensure_markets_loaded=lambda: None, - exchange=ex, - ) - self.assertTrue(out.get("ok")) - self.assertEqual(len(out.get("bars") or []), 600) - self.assertGreaterEqual(len(ex.calls), 3) - self.assertAlmostEqual(out["bars"][-1]["close"], 3.05) - - def test_pagination_stops_when_next_since_reaches_now(self): - """Gate 等:分页 since 不得越过当前时间,避免 from>to。""" - from hub_ohlcv_lib import TIMEFRAME_MS - - step = TIMEFRAME_MS["1d"] - now_ms = int(__import__("time").time() * 1000) - # 最后一页最后一根 K 的 next_since 将 >= now_ms,应停止不再请求 - last_open = ((now_ms // step) - 2) * step - page = [ - [last_open - step, 1.0, 1.1, 0.9, 1.0, 10.0], - [last_open, 1.1, 1.2, 1.0, 1.1, 11.0], - ] - ex = _FakeExchange([page]) - - out = fetch_ohlcv_for_hub( - symbol="ONDO/USDT", - timeframe="1d", - since_ms=last_open - step * 5, - limit=10, - normalize_symbol_input=lambda s: str(s).strip().upper(), - normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s, - ensure_markets_loaded=lambda: None, - exchange=ex, - ) - self.assertTrue(out.get("ok")) - self.assertGreaterEqual(len(out.get("bars") or []), 2) - self.assertLessEqual(len(ex.calls), 4) - - def test_aggregate_ohlcv_bars_buckets(self): - from hub_ohlcv_lib import TIMEFRAME_MS - - h1 = TIMEFRAME_MS["1h"] - h4 = TIMEFRAME_MS["4h"] - base = (1_700_000_000_000 // h4) * h4 - src = [ - { - "open_time_ms": base + i * h1, - "open": 1.0, - "high": 2.0, - "low": 0.5, - "close": 1.5, - "volume": 1.0, - } - for i in range(4) - ] - out = aggregate_ohlcv_bars(src, "4h") - self.assertEqual(len(out), 1) - self.assertEqual(out[0]["volume"], 4.0) - self.assertEqual(out[0]["high"], 2.0) - self.assertEqual(out[0]["low"], 0.5) - - -if __name__ == "__main__": - unittest.main() +"""hub_ohlcv_lib:分页拉取(Gate 等单次不足 chunk 时仍继续)。""" +from __future__ import annotations + +import unittest + +from lib.hub.hub_ohlcv_lib import ( + aggregate_ohlcv_bars, + bars_spacing_matches_timeframe, + fetch_ohlcv_for_hub, + normalize_price_tick, + price_tick_from_market, +) + + +class _FakeExchange: + def __init__(self, pages, *, timeframes=None): + self.pages = list(pages) + self.calls = [] + self.markets = {} + self.timeframes = timeframes if timeframes is not None else {} + + def fetch_ohlcv(self, symbol, timeframe=None, since=None, limit=None): + self.calls.append( + {"symbol": symbol, "since": since, "limit": limit, "timeframe": timeframe} + ) + if not self.pages: + return [] + page = self.pages.pop(0) + if since is None: + return page + return [b for b in page if b[0] >= since] + + +class TestHubOhlcvLib(unittest.TestCase): + def test_normalize_price_tick_snaps_powers_of_ten(self): + self.assertAlmostEqual(normalize_price_tick(0.00001), 0.00001) + self.assertAlmostEqual(normalize_price_tick(0.001), 0.001) + self.assertIsNone(normalize_price_tick(0)) + + def test_price_tick_from_decimal_precision(self): + class _Ex: + markets = {"BTC/USDT:USDT": {"precision": {"price": 2}, "info": {}, "limits": {}}} + + def load_markets(self): + return self.markets + + def market(self, sym): + return self.markets[sym] + + def price_to_precision(self, sym, price): + return "12345.67" + + tick = price_tick_from_market(_Ex(), "BTC/USDT:USDT") + self.assertAlmostEqual(tick, 0.01) + + def test_price_tick_from_binance_price_filter(self): + class _Ex: + markets = { + "BTC/USDT:USDT": { + "precision": {"price": 2}, + "info": { + "filters": [ + {"filterType": "PRICE_FILTER", "tickSize": "0.10"}, + {"filterType": "LOT_SIZE", "stepSize": "0.001"}, + ] + }, + "limits": {}, + } + } + + def load_markets(self): + return self.markets + + def market(self, sym): + return self.markets[sym] + + def price_to_precision(self, sym, price): + return "12345.6" + + from lib.hub.hub_ohlcv_lib import price_tick_from_market + + tick = price_tick_from_market(_Ex(), "BTC/USDT:USDT") + self.assertAlmostEqual(tick, 0.10) + + def test_price_tick_from_info_tick_size(self): + class _Ex: + markets = { + "INJ/USDT:USDT": { + "precision": {"price": 4}, + "info": {"tickSize": "0.001"}, + "limits": {}, + } + } + + def load_markets(self): + return self.markets + + def market(self, sym): + return self.markets[sym] + + def price_to_precision(self, sym, price): + return "7.123" + + from lib.hub.hub_ohlcv_lib import price_tick_from_market + + tick = price_tick_from_market(_Ex(), "INJ/USDT:USDT") + self.assertAlmostEqual(tick, 0.001) + + def test_full_fetch_without_since_paginates_okx_style(self): + """OKX 等无 since 单次约 300 根,须分页至 limit。""" + from lib.hub.hub_ohlcv_lib import TIMEFRAME_MS + + step = TIMEFRAME_MS["1h"] + want = 1000 + base = max(0, int(__import__("time").time() * 1000) - want * step) + pages = [ + [[base + i * step, 1.0, 1.1, 0.9, 1.05, 100.0] for i in range(300)], + [[base + (300 + i) * step, 2.0, 2.1, 1.9, 2.05, 200.0] for i in range(300)], + [[base + (600 + i) * step, 3.0, 3.1, 2.9, 3.05, 300.0] for i in range(300)], + [[base + (900 + i) * step, 4.0, 4.1, 3.9, 4.05, 400.0] for i in range(100)], + ] + ex = _FakeExchange(pages) + + out = fetch_ohlcv_for_hub( + symbol="ONDO/USDT", + timeframe="1h", + since_ms=None, + limit=want, + normalize_symbol_input=lambda s: str(s).strip().upper(), + normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s, + ensure_markets_loaded=lambda: None, + exchange=ex, + ) + self.assertTrue(out.get("ok")) + self.assertEqual(len(out.get("bars") or []), 1000) + self.assertGreaterEqual(len(ex.calls), 4) + self.assertAlmostEqual(out["bars"][-1]["close"], 4.05) + + def test_pagination_continues_when_page_smaller_than_chunk(self): + """Gate 等常返回 299 根/次,不应误判为已到末尾。""" + base = 1_700_000_000_000 + step = 4 * 60 * 60 * 1000 + page1 = [ + [base + i * step, 1.0, 1.1, 0.9, 1.05, 100.0] for i in range(299) + ] + page2 = [ + [base + (299 + i) * step, 2.0, 2.1, 1.9, 2.05, 200.0] for i in range(299) + ] + page3 = [ + [base + (598 + i) * step, 3.0, 3.1, 2.9, 3.05, 300.0] for i in range(50) + ] + ex = _FakeExchange([page1, page2, page3]) + + out = fetch_ohlcv_for_hub( + symbol="INJ/USDT", + timeframe="4h", + since_ms=base, + limit=600, + normalize_symbol_input=lambda s: str(s).strip().upper(), + normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s, + ensure_markets_loaded=lambda: None, + exchange=ex, + ) + self.assertTrue(out.get("ok")) + self.assertEqual(len(out.get("bars") or []), 600) + self.assertGreaterEqual(len(ex.calls), 3) + self.assertAlmostEqual(out["bars"][-1]["close"], 3.05) + + def test_pagination_stops_when_next_since_reaches_now(self): + """Gate 等:分页 since 不得越过当前时间,避免 from>to。""" + from lib.hub.hub_ohlcv_lib import TIMEFRAME_MS + + step = TIMEFRAME_MS["1d"] + now_ms = int(__import__("time").time() * 1000) + # 最后一页最后一根 K 的 next_since 将 >= now_ms,应停止不再请求 + last_open = ((now_ms // step) - 2) * step + page = [ + [last_open - step, 1.0, 1.1, 0.9, 1.0, 10.0], + [last_open, 1.1, 1.2, 1.0, 1.1, 11.0], + ] + ex = _FakeExchange([page]) + + out = fetch_ohlcv_for_hub( + symbol="ONDO/USDT", + timeframe="1d", + since_ms=last_open - step * 5, + limit=10, + normalize_symbol_input=lambda s: str(s).strip().upper(), + normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s, + ensure_markets_loaded=lambda: None, + exchange=ex, + ) + self.assertTrue(out.get("ok")) + self.assertGreaterEqual(len(out.get("bars") or []), 2) + self.assertLessEqual(len(ex.calls), 4) + + def test_aggregate_ohlcv_bars_buckets(self): + from lib.hub.hub_ohlcv_lib import TIMEFRAME_MS + + h1 = TIMEFRAME_MS["1h"] + h4 = TIMEFRAME_MS["4h"] + base = (1_700_000_000_000 // h4) * h4 + src = [ + { + "open_time_ms": base + i * h1, + "open": 1.0, + "high": 2.0, + "low": 0.5, + "close": 1.5, + "volume": 1.0, + } + for i in range(4) + ] + out = aggregate_ohlcv_bars(src, "4h") + self.assertEqual(len(out), 1) + self.assertEqual(out[0]["volume"], 4.0) + self.assertEqual(out[0]["high"], 2.0) + self.assertEqual(out[0]["low"], 0.5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hub_supervisor_lib.py b/tests/test_hub_supervisor_lib.py index 06385d5..9479c7d 100644 --- a/tests/test_hub_supervisor_lib.py +++ b/tests/test_hub_supervisor_lib.py @@ -5,7 +5,12 @@ import json import sys from pathlib import Path -import pytest +try: + import pytest +except ImportError: # pragma: no cover + import unittest + + raise unittest.SkipTest("pytest not installed") ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(ROOT)) diff --git a/tests/test_hub_symbol_archive_lib.py b/tests/test_hub_symbol_archive_lib.py index 7e71fe4..b07396d 100644 --- a/tests/test_hub_symbol_archive_lib.py +++ b/tests/test_hub_symbol_archive_lib.py @@ -1,348 +1,348 @@ -"""币种档案库:5m 聚合与视窗计算。""" - -from __future__ import annotations - -import tempfile -from pathlib import Path - -from hub_ohlcv_lib import aggregate_ohlcv_bars -from datetime import datetime, timezone -from zoneinfo import ZoneInfo - -from hub_symbol_archive_lib import ( - CHART_DISPLAY_TZ, - _compute_period_stats, - _fill_missing_bars, - init_db, - list_daily_trades, - load_symbol_trades, - ms_to_wall_clock_str, - parse_wall_clock_ms, - resolve_archive_chart, - trading_day_bounds_ms, - upsert_bars_5m, - upsert_trade_overlay, - list_symbol_rows, - upsert_trades_cache, -) - - -def _seed_5m_bars( - db: Path, - start_ms: int, - count: int, - step: int = 300_000, - *, - ex: str = "gate", - sym: str = "ONDO", -) -> None: - bars = [] - price = 1.0 - for i in range(count): - o = start_ms + i * step - price += 0.001 - bars.append( - { - "open_time_ms": o, - "open": price, - "high": price + 0.002, - "low": price - 0.001, - "close": price + 0.001, - "volume": 100 + i, - } - ) - upsert_bars_5m(ex, sym, bars, db_path=db) - - -def test_aggregate_15m_from_5m(): - start = 1_700_000_000_000 - bars = [] - for i in range(6): - t = start + i * 300_000 - bars.append( - { - "open_time_ms": t, - "open": 1.0, - "high": 1.1, - "low": 0.9, - "close": 1.05, - "volume": 10, - } - ) - agg = aggregate_ohlcv_bars(bars, "15m") - assert len(agg) >= 1 - assert agg[-1]["close"] == bars[-1]["close"] - assert agg[0]["open_time_ms"] <= agg[1]["open_time_ms"] - - -def test_resolve_archive_chart_15m(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "archive.db" - init_db(db) - anchor = 1_700_000_000_000 - _seed_5m_bars(db, anchor - 50 * 300_000, 120) - out = resolve_archive_chart( - "gate", - "ONDO", - "15m", - anchor_ms=anchor, - mode="hold", - bars=40, - db_path=db, - ) - assert out["ok"] is True - assert out["timeframe"] == "15m" - assert len(out["candles"]) >= 10 - - -def test_fill_missing_bars_continuity(): - period = 300_000 - start = (1_700_000_000_000 // period) * period - bars = [ - { - "open_time_ms": start, - "open": 1.0, - "high": 1.1, - "low": 0.9, - "close": 1.05, - "volume": 10, - }, - { - "open_time_ms": start + period * 2, - "open": 1.05, - "high": 1.15, - "low": 1.0, - "close": 1.1, - "volume": 8, - }, - ] - filled = _fill_missing_bars(bars, period, start, start + period * 2) - assert len(filled) >= 3 - assert any(b.get("filled") for b in filled) - - -def test_resolve_archive_chart_history_range(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "archive.db" - init_db(db) - open_ms = 1_700_000_000_000 - close_ms = open_ms + 6 * 3600_000 - _seed_5m_bars(db, open_ms - 20 * 300_000, 200, ex="gate", sym="BNB/USDT") - out = resolve_archive_chart( - "gate", - "BNB/USDT", - "15m", - opened_ms=open_ms, - closed_ms=close_ms, - mode="hold", - range_mode="history", - db_path=db, - ) - assert out["ok"] is True - assert out.get("range_mode") == "history" - assert out.get("window_end_ms") <= close_ms + 4 * 3600_000 - assert len(out["candles"]) >= 40 - - -def test_sync_prunes_missing_trades(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "archive.db" - init_db(db) - upsert_trades_cache( - "gate", - [ - {"id": 1, "symbol": "BNB/USDT", "result": "止损", "pnl_amount": -1}, - {"id": 2, "symbol": "BNB/USDT", "result": "止盈", "pnl_amount": 1}, - ], - db_path=db, - prune_missing=False, - ) - stats = upsert_trades_cache( - "gate", - [{"id": 1, "symbol": "BNB/USDT", "result": "止损", "pnl_amount": -1}], - db_path=db, - prune_missing=True, - ) - rows = load_symbol_trades("gate", "BNB/USDT", db_path=db) - assert len(rows) == 1 - assert rows[0]["trade_id"] == 1 - assert stats["removed"] == 1 - - -def test_list_with_overlay_filters(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "archive.db" - init_db(db) - upsert_trades_cache( - "gate", - [ - { - "id": 1, - "symbol": "ONDO", - "direction": "long", - "result": "止盈", - "pnl_amount": 12.5, - "opened_at": "2026-01-01 10:00:00", - "closed_at": "2026-01-01 12:00:00", - "opened_at_ms": 1_700_000_000_000, - "closed_at_ms": 1_700_007_200_000, - }, - { - "id": 2, - "symbol": "ONDO", - "direction": "short", - "result": "止损", - "pnl_amount": -3.2, - "opened_at": "2026-01-02 10:00:00", - "closed_at": "2026-01-02 11:00:00", - "opened_at_ms": 1_700_086_400_000, - "closed_at_ms": 1_700_090_000_000, - }, - ], - db_path=db, - ) - upsert_trade_overlay("gate", 2, behavior_tag="sick", note="追高", db_path=db) - rows = list_symbol_rows(db_path=db) - assert len(rows) == 1 - assert rows[0]["trade_count"] == 2 - sick_only = list_symbol_rows(filter_sick=True, db_path=db) - assert len(sick_only) == 1 - profit_only = list_symbol_rows(filter_profit=True, db_path=db) - assert len(profit_only) == 1 - - -def test_parse_wall_clock_ms_uses_utc_plus_8(): - ms = parse_wall_clock_ms("2026-06-07 20:30:00") - assert ms is not None - dt_utc = datetime.fromtimestamp(ms / 1000.0, tz=timezone.utc) - dt_bj = dt_utc.astimezone(CHART_DISPLAY_TZ) - assert dt_bj.strftime("%Y-%m-%d %H:%M:%S") == "2026-06-07 20:30:00" - assert ms_to_wall_clock_str(ms) == "2026-06-07 20:30:00" - assert parse_wall_clock_ms("2026-06-07 20:30") == ms - - -def test_parse_wall_clock_ms_accepts_epoch_strings(): - ms = 1_700_000_000_000 - assert parse_wall_clock_ms(str(ms)) == ms - assert parse_wall_clock_ms(str(ms // 1000)) == ms - - -def test_resolve_archive_chart_history_uses_trade_span_not_200_bars(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "archive.db" - init_db(db) - opened = 1_700_000_000_000 - closed = opened + 20 * 24 * 3600_000 - _seed_5m_bars(db, opened - 35 * 24 * 3600_000, 40 * 24 * 12) - out = resolve_archive_chart( - "gate", - "ONDO", - "15m", - opened_ms=opened, - closed_ms=closed, - mode="hold", - bars=200, - range_mode="history", - db_path=db, - ) - assert out["ok"] is True - assert out["range_mode"] == "history" - assert out["bar_count"] > 200 - - -def test_upsert_forces_sync_exchange_key(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "archive.db" - init_db(db) - upsert_trades_cache( - "gate_bot", - [ - { - "id": 77, - "exchange_key": "gate", - "account_exchange_key": "gate", - "symbol": "ETH/USDT", - "result": "止损", - "pnl_amount": -1, - "opened_at_ms": 1_700_000_000_000, - "closed_at_ms": 1_700_007_200_000, - } - ], - db_path=db, - ) - rows = load_symbol_trades("gate_bot", "ETH/USDT", db_path=db) - assert len(rows) == 1 - assert rows[0]["exchange_key"] == "gate_bot" - assert "account_exchange_key" not in rows[0] - - -def test_compute_period_stats_win_loss_metrics(): - rows = [ - {"exchange_key": "binance", "pnl_amount": 10.0, "behavior_tag": ""}, - {"exchange_key": "binance", "pnl_amount": 4.0, "behavior_tag": ""}, - {"exchange_key": "okx", "pnl_amount": -3.0, "behavior_tag": "sick"}, - {"exchange_key": "okx", "pnl_amount": -6.0, "behavior_tag": ""}, - ] - st = _compute_period_stats(rows) - assert st["open_count"] == 4 - assert st["win_count"] == 2 - assert st["loss_count"] == 2 - assert st["avg_win"] == 7.0 - assert st["avg_loss"] == -4.5 - assert st["max_win"] == 10.0 - assert st["max_loss"] == -6.0 - assert st["win_rate"] == 50.0 - assert st["profit_loss_ratio"] == round(7.0 / 4.5, 2) - assert st["sick_count"] == 1 - assert st["pnl_total"] == 5.0 - assert st["pnl_ex_sick"] == 8.0 - assert st["by_exchange"]["binance"]["win_count"] == 2 - assert st["by_exchange"]["binance"]["win_rate"] == 100.0 - assert st["by_exchange"]["binance"]["profit_loss_ratio"] is None - - -def test_list_daily_trades_search_filters_stats(): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "archive.db" - init_db(db) - day = "2023-11-15" - start_ms, _ = trading_day_bounds_ms(day) - btc_close = start_ms + 3_600_000 - eth_close = start_ms + 7_200_000 - upsert_trades_cache( - "gate", - [ - { - "id": 1, - "symbol": "BTC/USDT", - "result": "止盈", - "pnl_amount": 5.0, - "opened_at_ms": start_ms, - "closed_at_ms": btc_close, - }, - { - "id": 2, - "symbol": "ETH/USDT", - "result": "止损", - "pnl_amount": -2.0, - "opened_at_ms": btc_close, - "closed_at_ms": eth_close, - }, - ], - db_path=db, - ) - payload = list_daily_trades( - period="range", - date_from=day, - date_to=day, - search="btc", - db_path=db, - ) - assert len(payload["trades"]) == 1 - assert payload["trades"][0]["symbol"] == "BTC/USDT" - st = payload["stats"] - assert st["open_count"] == 1 - assert st["win_count"] == 1 - assert st["loss_count"] == 0 - assert st["max_win"] == 5.0 - assert st["pnl_total"] == 5.0 +"""币种档案库:5m 聚合与视窗计算。""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +from lib.hub.hub_ohlcv_lib import aggregate_ohlcv_bars +from datetime import datetime, timezone +from zoneinfo import ZoneInfo + +from lib.hub.hub_symbol_archive_lib import ( + CHART_DISPLAY_TZ, + _compute_period_stats, + _fill_missing_bars, + init_db, + list_daily_trades, + load_symbol_trades, + ms_to_wall_clock_str, + parse_wall_clock_ms, + resolve_archive_chart, + trading_day_bounds_ms, + upsert_bars_5m, + upsert_trade_overlay, + list_symbol_rows, + upsert_trades_cache, +) + + +def _seed_5m_bars( + db: Path, + start_ms: int, + count: int, + step: int = 300_000, + *, + ex: str = "gate", + sym: str = "ONDO", +) -> None: + bars = [] + price = 1.0 + for i in range(count): + o = start_ms + i * step + price += 0.001 + bars.append( + { + "open_time_ms": o, + "open": price, + "high": price + 0.002, + "low": price - 0.001, + "close": price + 0.001, + "volume": 100 + i, + } + ) + upsert_bars_5m(ex, sym, bars, db_path=db) + + +def test_aggregate_15m_from_5m(): + start = 1_700_000_000_000 + bars = [] + for i in range(6): + t = start + i * 300_000 + bars.append( + { + "open_time_ms": t, + "open": 1.0, + "high": 1.1, + "low": 0.9, + "close": 1.05, + "volume": 10, + } + ) + agg = aggregate_ohlcv_bars(bars, "15m") + assert len(agg) >= 1 + assert agg[-1]["close"] == bars[-1]["close"] + assert agg[0]["open_time_ms"] <= agg[1]["open_time_ms"] + + +def test_resolve_archive_chart_15m(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "archive.db" + init_db(db) + anchor = 1_700_000_000_000 + _seed_5m_bars(db, anchor - 50 * 300_000, 120) + out = resolve_archive_chart( + "gate", + "ONDO", + "15m", + anchor_ms=anchor, + mode="hold", + bars=40, + db_path=db, + ) + assert out["ok"] is True + assert out["timeframe"] == "15m" + assert len(out["candles"]) >= 10 + + +def test_fill_missing_bars_continuity(): + period = 300_000 + start = (1_700_000_000_000 // period) * period + bars = [ + { + "open_time_ms": start, + "open": 1.0, + "high": 1.1, + "low": 0.9, + "close": 1.05, + "volume": 10, + }, + { + "open_time_ms": start + period * 2, + "open": 1.05, + "high": 1.15, + "low": 1.0, + "close": 1.1, + "volume": 8, + }, + ] + filled = _fill_missing_bars(bars, period, start, start + period * 2) + assert len(filled) >= 3 + assert any(b.get("filled") for b in filled) + + +def test_resolve_archive_chart_history_range(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "archive.db" + init_db(db) + open_ms = 1_700_000_000_000 + close_ms = open_ms + 6 * 3600_000 + _seed_5m_bars(db, open_ms - 20 * 300_000, 200, ex="gate", sym="BNB/USDT") + out = resolve_archive_chart( + "gate", + "BNB/USDT", + "15m", + opened_ms=open_ms, + closed_ms=close_ms, + mode="hold", + range_mode="history", + db_path=db, + ) + assert out["ok"] is True + assert out.get("range_mode") == "history" + assert out.get("window_end_ms") <= close_ms + 4 * 3600_000 + assert len(out["candles"]) >= 40 + + +def test_sync_prunes_missing_trades(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "archive.db" + init_db(db) + upsert_trades_cache( + "gate", + [ + {"id": 1, "symbol": "BNB/USDT", "result": "止损", "pnl_amount": -1}, + {"id": 2, "symbol": "BNB/USDT", "result": "止盈", "pnl_amount": 1}, + ], + db_path=db, + prune_missing=False, + ) + stats = upsert_trades_cache( + "gate", + [{"id": 1, "symbol": "BNB/USDT", "result": "止损", "pnl_amount": -1}], + db_path=db, + prune_missing=True, + ) + rows = load_symbol_trades("gate", "BNB/USDT", db_path=db) + assert len(rows) == 1 + assert rows[0]["trade_id"] == 1 + assert stats["removed"] == 1 + + +def test_list_with_overlay_filters(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "archive.db" + init_db(db) + upsert_trades_cache( + "gate", + [ + { + "id": 1, + "symbol": "ONDO", + "direction": "long", + "result": "止盈", + "pnl_amount": 12.5, + "opened_at": "2026-01-01 10:00:00", + "closed_at": "2026-01-01 12:00:00", + "opened_at_ms": 1_700_000_000_000, + "closed_at_ms": 1_700_007_200_000, + }, + { + "id": 2, + "symbol": "ONDO", + "direction": "short", + "result": "止损", + "pnl_amount": -3.2, + "opened_at": "2026-01-02 10:00:00", + "closed_at": "2026-01-02 11:00:00", + "opened_at_ms": 1_700_086_400_000, + "closed_at_ms": 1_700_090_000_000, + }, + ], + db_path=db, + ) + upsert_trade_overlay("gate", 2, behavior_tag="sick", note="追高", db_path=db) + rows = list_symbol_rows(db_path=db) + assert len(rows) == 1 + assert rows[0]["trade_count"] == 2 + sick_only = list_symbol_rows(filter_sick=True, db_path=db) + assert len(sick_only) == 1 + profit_only = list_symbol_rows(filter_profit=True, db_path=db) + assert len(profit_only) == 1 + + +def test_parse_wall_clock_ms_uses_utc_plus_8(): + ms = parse_wall_clock_ms("2026-06-07 20:30:00") + assert ms is not None + dt_utc = datetime.fromtimestamp(ms / 1000.0, tz=timezone.utc) + dt_bj = dt_utc.astimezone(CHART_DISPLAY_TZ) + assert dt_bj.strftime("%Y-%m-%d %H:%M:%S") == "2026-06-07 20:30:00" + assert ms_to_wall_clock_str(ms) == "2026-06-07 20:30:00" + assert parse_wall_clock_ms("2026-06-07 20:30") == ms + + +def test_parse_wall_clock_ms_accepts_epoch_strings(): + ms = 1_700_000_000_000 + assert parse_wall_clock_ms(str(ms)) == ms + assert parse_wall_clock_ms(str(ms // 1000)) == ms + + +def test_resolve_archive_chart_history_uses_trade_span_not_200_bars(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "archive.db" + init_db(db) + opened = 1_700_000_000_000 + closed = opened + 20 * 24 * 3600_000 + _seed_5m_bars(db, opened - 35 * 24 * 3600_000, 40 * 24 * 12) + out = resolve_archive_chart( + "gate", + "ONDO", + "15m", + opened_ms=opened, + closed_ms=closed, + mode="hold", + bars=200, + range_mode="history", + db_path=db, + ) + assert out["ok"] is True + assert out["range_mode"] == "history" + assert out["bar_count"] > 200 + + +def test_upsert_forces_sync_exchange_key(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "archive.db" + init_db(db) + upsert_trades_cache( + "gate_bot", + [ + { + "id": 77, + "exchange_key": "gate", + "account_exchange_key": "gate", + "symbol": "ETH/USDT", + "result": "止损", + "pnl_amount": -1, + "opened_at_ms": 1_700_000_000_000, + "closed_at_ms": 1_700_007_200_000, + } + ], + db_path=db, + ) + rows = load_symbol_trades("gate_bot", "ETH/USDT", db_path=db) + assert len(rows) == 1 + assert rows[0]["exchange_key"] == "gate_bot" + assert "account_exchange_key" not in rows[0] + + +def test_compute_period_stats_win_loss_metrics(): + rows = [ + {"exchange_key": "binance", "pnl_amount": 10.0, "behavior_tag": ""}, + {"exchange_key": "binance", "pnl_amount": 4.0, "behavior_tag": ""}, + {"exchange_key": "okx", "pnl_amount": -3.0, "behavior_tag": "sick"}, + {"exchange_key": "okx", "pnl_amount": -6.0, "behavior_tag": ""}, + ] + st = _compute_period_stats(rows) + assert st["open_count"] == 4 + assert st["win_count"] == 2 + assert st["loss_count"] == 2 + assert st["avg_win"] == 7.0 + assert st["avg_loss"] == -4.5 + assert st["max_win"] == 10.0 + assert st["max_loss"] == -6.0 + assert st["win_rate"] == 50.0 + assert st["profit_loss_ratio"] == round(7.0 / 4.5, 2) + assert st["sick_count"] == 1 + assert st["pnl_total"] == 5.0 + assert st["pnl_ex_sick"] == 8.0 + assert st["by_exchange"]["binance"]["win_count"] == 2 + assert st["by_exchange"]["binance"]["win_rate"] == 100.0 + assert st["by_exchange"]["binance"]["profit_loss_ratio"] is None + + +def test_list_daily_trades_search_filters_stats(): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "archive.db" + init_db(db) + day = "2023-11-15" + start_ms, _ = trading_day_bounds_ms(day) + btc_close = start_ms + 3_600_000 + eth_close = start_ms + 7_200_000 + upsert_trades_cache( + "gate", + [ + { + "id": 1, + "symbol": "BTC/USDT", + "result": "止盈", + "pnl_amount": 5.0, + "opened_at_ms": start_ms, + "closed_at_ms": btc_close, + }, + { + "id": 2, + "symbol": "ETH/USDT", + "result": "止损", + "pnl_amount": -2.0, + "opened_at_ms": btc_close, + "closed_at_ms": eth_close, + }, + ], + db_path=db, + ) + payload = list_daily_trades( + period="range", + date_from=day, + date_to=day, + search="btc", + db_path=db, + ) + assert len(payload["trades"]) == 1 + assert payload["trades"][0]["symbol"] == "BTC/USDT" + st = payload["stats"] + assert st["open_count"] == 1 + assert st["win_count"] == 1 + assert st["loss_count"] == 0 + assert st["max_win"] == 5.0 + assert st["pnl_total"] == 5.0 diff --git a/tests/test_hub_trades_archive_merge.py b/tests/test_hub_trades_archive_merge.py index d530c09..63e83bc 100644 --- a/tests/test_hub_trades_archive_merge.py +++ b/tests/test_hub_trades_archive_merge.py @@ -1,102 +1,102 @@ -"""档案交易:strategy_trade_snapshots 补全 gate_bot 漏记。""" - -from __future__ import annotations - -import sqlite3 -import tempfile -from datetime import datetime, timedelta -from pathlib import Path - -from hub_trades_lib import fetch_trades_for_archive - - -def _init_db(path: Path) -> sqlite3.Connection: - conn = sqlite3.connect(str(path)) - conn.row_factory = sqlite3.Row - conn.execute( - """ - CREATE TABLE trade_records ( - id INTEGER PRIMARY KEY, - symbol TEXT, - direction TEXT, - result TEXT, - pnl_amount REAL, - opened_at TEXT, - closed_at TEXT, - opened_at_ms INTEGER, - closed_at_ms INTEGER, - created_at TEXT, - trend_plan_id INTEGER - ) - """ - ) - conn.execute( - """ - CREATE TABLE strategy_trade_snapshots ( - id INTEGER PRIMARY KEY, - strategy_type TEXT, - source_id INTEGER, - symbol TEXT, - direction TEXT, - result_label TEXT, - status_at_close TEXT, - opened_at TEXT, - closed_at TEXT, - pnl_amount REAL, - snapshot_json TEXT, - created_at TEXT - ) - """ - ) - return conn - - -def test_merge_snapshot_when_trade_record_missing(): - with tempfile.TemporaryDirectory() as td: - conn = _init_db(Path(td) / "t.db") - closed = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S") - conn.execute( - """ - INSERT INTO strategy_trade_snapshots ( - id, strategy_type, source_id, symbol, direction, - result_label, opened_at, closed_at, pnl_amount, snapshot_json, created_at - ) VALUES (?,?,?,?,?,?,?,?,?,?,?) - """, - (7, "trend_pullback", 42, "ONDO/USDT", "long", "止损", closed, closed, -1.2, "{}", closed), - ) - conn.commit() - trades = fetch_trades_for_archive(conn, days=30, limit=50) - conn.close() - assert len(trades) == 1 - assert trades[0]["symbol"] == "ONDO/USDT" - assert trades[0]["id"] == -7 - assert trades[0].get("from_snapshot") is True - - -def test_skip_snapshot_when_trade_record_exists(): - with tempfile.TemporaryDirectory() as td: - conn = _init_db(Path(td) / "t.db") - closed = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S") - conn.execute( - """ - INSERT INTO trade_records ( - id, symbol, direction, result, pnl_amount, - opened_at, closed_at, opened_at_ms, closed_at_ms, created_at, trend_plan_id - ) VALUES (?,?,?,?,?,?,?,?,?,?,?) - """, - (1, "ONDO/USDT", "long", "止损", -1.2, closed, closed, 1, 2, closed, 42), - ) - conn.execute( - """ - INSERT INTO strategy_trade_snapshots ( - id, strategy_type, source_id, symbol, direction, - result_label, opened_at, closed_at, pnl_amount, snapshot_json, created_at - ) VALUES (?,?,?,?,?,?,?,?,?,?,?) - """, - (7, "trend_pullback", 42, "ONDO/USDT", "long", "止损", closed, closed, -1.2, "{}", closed), - ) - conn.commit() - trades = fetch_trades_for_archive(conn, days=30, limit=50) - conn.close() - assert len(trades) == 1 - assert trades[0]["id"] == 1 +"""档案交易:strategy_trade_snapshots 补全 gate_bot 漏记。""" + +from __future__ import annotations + +import sqlite3 +import tempfile +from datetime import datetime, timedelta +from pathlib import Path + +from lib.hub.hub_trades_lib import fetch_trades_for_archive + + +def _init_db(path: Path) -> sqlite3.Connection: + conn = sqlite3.connect(str(path)) + conn.row_factory = sqlite3.Row + conn.execute( + """ + CREATE TABLE trade_records ( + id INTEGER PRIMARY KEY, + symbol TEXT, + direction TEXT, + result TEXT, + pnl_amount REAL, + opened_at TEXT, + closed_at TEXT, + opened_at_ms INTEGER, + closed_at_ms INTEGER, + created_at TEXT, + trend_plan_id INTEGER + ) + """ + ) + conn.execute( + """ + CREATE TABLE strategy_trade_snapshots ( + id INTEGER PRIMARY KEY, + strategy_type TEXT, + source_id INTEGER, + symbol TEXT, + direction TEXT, + result_label TEXT, + status_at_close TEXT, + opened_at TEXT, + closed_at TEXT, + pnl_amount REAL, + snapshot_json TEXT, + created_at TEXT + ) + """ + ) + return conn + + +def test_merge_snapshot_when_trade_record_missing(): + with tempfile.TemporaryDirectory() as td: + conn = _init_db(Path(td) / "t.db") + closed = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S") + conn.execute( + """ + INSERT INTO strategy_trade_snapshots ( + id, strategy_type, source_id, symbol, direction, + result_label, opened_at, closed_at, pnl_amount, snapshot_json, created_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?) + """, + (7, "trend_pullback", 42, "ONDO/USDT", "long", "止损", closed, closed, -1.2, "{}", closed), + ) + conn.commit() + trades = fetch_trades_for_archive(conn, days=30, limit=50) + conn.close() + assert len(trades) == 1 + assert trades[0]["symbol"] == "ONDO/USDT" + assert trades[0]["id"] == -7 + assert trades[0].get("from_snapshot") is True + + +def test_skip_snapshot_when_trade_record_exists(): + with tempfile.TemporaryDirectory() as td: + conn = _init_db(Path(td) / "t.db") + closed = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S") + conn.execute( + """ + INSERT INTO trade_records ( + id, symbol, direction, result, pnl_amount, + opened_at, closed_at, opened_at_ms, closed_at_ms, created_at, trend_plan_id + ) VALUES (?,?,?,?,?,?,?,?,?,?,?) + """, + (1, "ONDO/USDT", "long", "止损", -1.2, closed, closed, 1, 2, closed, 42), + ) + conn.execute( + """ + INSERT INTO strategy_trade_snapshots ( + id, strategy_type, source_id, symbol, direction, + result_label, opened_at, closed_at, pnl_amount, snapshot_json, created_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?) + """, + (7, "trend_pullback", 42, "ONDO/USDT", "long", "止损", closed, closed, -1.2, "{}", closed), + ) + conn.commit() + trades = fetch_trades_for_archive(conn, days=30, limit=50) + conn.close() + assert len(trades) == 1 + assert trades[0]["id"] == 1 diff --git a/tests/test_hub_trades_lib.py b/tests/test_hub_trades_lib.py index bd8fa78..4930bfd 100644 --- a/tests/test_hub_trades_lib.py +++ b/tests/test_hub_trades_lib.py @@ -1,198 +1,198 @@ -"""hub_trades_lib 单元测试。""" -from __future__ import annotations - -import sqlite3 -import unittest -from datetime import datetime - -from hub_trades_lib import ( - fetch_trades_for_trading_day, - summarize_trades, - trading_day_from_dt, - trading_day_window_bounds, -) - - -class HubTradesLibTest(unittest.TestCase): - def test_trading_day_reset(self): - dt = datetime(2026, 6, 6, 7, 30, 0) - self.assertEqual(trading_day_from_dt(dt, 8), "2026-06-05") - dt2 = datetime(2026, 6, 6, 8, 0, 0) - self.assertEqual(trading_day_from_dt(dt2, 8), "2026-06-06") - - def test_trading_day_window_bounds(self): - start, end = trading_day_window_bounds("2026-06-06", 8) - self.assertEqual(start, "2026-06-06 08:00:00") - self.assertEqual(end, "2026-06-07 07:59:59") - - def test_fetch_and_summarize(self): - conn = sqlite3.connect(":memory:") - conn.row_factory = sqlite3.Row - conn.execute( - """CREATE TABLE trade_records ( - symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT, - pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL, - closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT, - created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL, - trade_style TEXT, entry_reason TEXT, reviewed_at TEXT - )""" - ) - conn.execute( - "INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", - ( - "ONDO/USDT", - "short", - "止损", - None, - -0.5, - None, - None, - "2026-06-06 10:00:00", - None, - "2026-06-06 09:00:00", - None, - "2026-06-06 10:00:00", - "趋势回调", - None, - None, - "trend", - "", - None, - ), - ) - conn.commit() - rows = fetch_trades_for_trading_day(conn, "2026-06-06") - self.assertEqual(len(rows), 1) - stats = summarize_trades(rows) - self.assertEqual(stats["closed_count"], 1) - self.assertEqual(stats["loss_count"], 1) - self.assertAlmostEqual(stats["total_pnl_u"], -0.5) - conn.close() - - def test_early_morning_belongs_prev_trading_day(self): - conn = sqlite3.connect(":memory:") - conn.row_factory = sqlite3.Row - conn.execute( - """CREATE TABLE trade_records ( - symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT, - pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL, - closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT, - created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL, - trade_style TEXT, entry_reason TEXT, reviewed_at TEXT - )""" - ) - conn.execute( - "INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", - ( - "BTC/USDT", - "long", - "止盈", - None, - 1.2, - None, - None, - "2026-06-07 07:30:00", - None, - "2026-06-07 06:00:00", - None, - "2026-06-07 07:30:00", - "关键位", - None, - None, - "trend", - "", - None, - ), - ) - conn.commit() - self.assertEqual(len(fetch_trades_for_trading_day(conn, "2026-06-07")), 0) - self.assertEqual(len(fetch_trades_for_trading_day(conn, "2026-06-06")), 1) - conn.close() - - def test_reviewed_fields_preferred(self): - conn = sqlite3.connect(":memory:") - conn.row_factory = sqlite3.Row - conn.execute( - """CREATE TABLE trade_records ( - symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT, - pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL, - closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT, - created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL, - trade_style TEXT, entry_reason TEXT, reviewed_at TEXT - )""" - ) - conn.execute( - "INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", - ( - "ETH/USDT", - "long", - "止损", - "止盈", - -0.5, - 2.0, - None, - "2026-06-06 09:00:00", - "2026-06-06 11:00:00", - "2026-06-06 08:00:00", - None, - "2026-06-06 11:00:00", - "趋势回调", - None, - None, - "trend", - "", - "2026-06-06 12:00:00", - ), - ) - conn.commit() - rows = fetch_trades_for_trading_day(conn, "2026-06-06") - self.assertEqual(len(rows), 1) - self.assertEqual(rows[0]["result"], "止盈") - self.assertAlmostEqual(rows[0]["pnl_amount"], 2.0) - self.assertTrue(rows[0]["reviewed"]) - conn.close() - - def test_time_close_result_included(self): - conn = sqlite3.connect(":memory:") - conn.row_factory = sqlite3.Row - conn.execute( - """CREATE TABLE trade_records ( - symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT, - pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL, - closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT, - created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL, - trade_style TEXT, entry_reason TEXT, reviewed_at TEXT - )""" - ) - conn.execute( - "INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", - ( - "BTC/USDT", - "long", - "时间平仓", - None, - 1.2, - None, - None, - "2026-06-06 12:00:00", - None, - "2026-06-06 08:00:00", - None, - "2026-06-06 12:00:00", - "趋势回调", - None, - None, - "trend", - "", - None, - ), - ) - conn.commit() - rows = fetch_trades_for_trading_day(conn, "2026-06-06") - self.assertEqual(len(rows), 1) - self.assertEqual(rows[0]["result"], "时间平仓") - conn.close() - - -if __name__ == "__main__": - unittest.main() +"""hub_trades_lib 单元测试。""" +from __future__ import annotations + +import sqlite3 +import unittest +from datetime import datetime + +from lib.hub.hub_trades_lib import ( + fetch_trades_for_trading_day, + summarize_trades, + trading_day_from_dt, + trading_day_window_bounds, +) + + +class HubTradesLibTest(unittest.TestCase): + def test_trading_day_reset(self): + dt = datetime(2026, 6, 6, 7, 30, 0) + self.assertEqual(trading_day_from_dt(dt, 8), "2026-06-05") + dt2 = datetime(2026, 6, 6, 8, 0, 0) + self.assertEqual(trading_day_from_dt(dt2, 8), "2026-06-06") + + def test_trading_day_window_bounds(self): + start, end = trading_day_window_bounds("2026-06-06", 8) + self.assertEqual(start, "2026-06-06 08:00:00") + self.assertEqual(end, "2026-06-07 07:59:59") + + def test_fetch_and_summarize(self): + conn = sqlite3.connect(":memory:") + conn.row_factory = sqlite3.Row + conn.execute( + """CREATE TABLE trade_records ( + symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT, + pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL, + closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT, + created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL, + trade_style TEXT, entry_reason TEXT, reviewed_at TEXT + )""" + ) + conn.execute( + "INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + ( + "ONDO/USDT", + "short", + "止损", + None, + -0.5, + None, + None, + "2026-06-06 10:00:00", + None, + "2026-06-06 09:00:00", + None, + "2026-06-06 10:00:00", + "趋势回调", + None, + None, + "trend", + "", + None, + ), + ) + conn.commit() + rows = fetch_trades_for_trading_day(conn, "2026-06-06") + self.assertEqual(len(rows), 1) + stats = summarize_trades(rows) + self.assertEqual(stats["closed_count"], 1) + self.assertEqual(stats["loss_count"], 1) + self.assertAlmostEqual(stats["total_pnl_u"], -0.5) + conn.close() + + def test_early_morning_belongs_prev_trading_day(self): + conn = sqlite3.connect(":memory:") + conn.row_factory = sqlite3.Row + conn.execute( + """CREATE TABLE trade_records ( + symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT, + pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL, + closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT, + created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL, + trade_style TEXT, entry_reason TEXT, reviewed_at TEXT + )""" + ) + conn.execute( + "INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + ( + "BTC/USDT", + "long", + "止盈", + None, + 1.2, + None, + None, + "2026-06-07 07:30:00", + None, + "2026-06-07 06:00:00", + None, + "2026-06-07 07:30:00", + "关键位", + None, + None, + "trend", + "", + None, + ), + ) + conn.commit() + self.assertEqual(len(fetch_trades_for_trading_day(conn, "2026-06-07")), 0) + self.assertEqual(len(fetch_trades_for_trading_day(conn, "2026-06-06")), 1) + conn.close() + + def test_reviewed_fields_preferred(self): + conn = sqlite3.connect(":memory:") + conn.row_factory = sqlite3.Row + conn.execute( + """CREATE TABLE trade_records ( + symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT, + pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL, + closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT, + created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL, + trade_style TEXT, entry_reason TEXT, reviewed_at TEXT + )""" + ) + conn.execute( + "INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + ( + "ETH/USDT", + "long", + "止损", + "止盈", + -0.5, + 2.0, + None, + "2026-06-06 09:00:00", + "2026-06-06 11:00:00", + "2026-06-06 08:00:00", + None, + "2026-06-06 11:00:00", + "趋势回调", + None, + None, + "trend", + "", + "2026-06-06 12:00:00", + ), + ) + conn.commit() + rows = fetch_trades_for_trading_day(conn, "2026-06-06") + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["result"], "止盈") + self.assertAlmostEqual(rows[0]["pnl_amount"], 2.0) + self.assertTrue(rows[0]["reviewed"]) + conn.close() + + def test_time_close_result_included(self): + conn = sqlite3.connect(":memory:") + conn.row_factory = sqlite3.Row + conn.execute( + """CREATE TABLE trade_records ( + symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT, + pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL, + closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT, + created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL, + trade_style TEXT, entry_reason TEXT, reviewed_at TEXT + )""" + ) + conn.execute( + "INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + ( + "BTC/USDT", + "long", + "时间平仓", + None, + 1.2, + None, + None, + "2026-06-06 12:00:00", + None, + "2026-06-06 08:00:00", + None, + "2026-06-06 12:00:00", + "趋势回调", + None, + None, + "trend", + "", + None, + ), + ) + conn.commit() + rows = fetch_trades_for_trading_day(conn, "2026-06-06") + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["result"], "时间平仓") + conn.close() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hub_trades_review_fields.py b/tests/test_hub_trades_review_fields.py index f758167..66118f0 100644 --- a/tests/test_hub_trades_review_fields.py +++ b/tests/test_hub_trades_review_fields.py @@ -1,115 +1,115 @@ -"""档案交易:复盘字段优先(开仓类型、持仓时长、开平仓时间)。""" - -from __future__ import annotations - -import tempfile -import unittest -from datetime import datetime, timedelta -from pathlib import Path - -from hub_symbol_archive_lib import init_db, load_symbol_trades, upsert_trades_cache -from hub_trades_lib import ( - _normalize_archive_trade_row, - display_entry_type_label, - effective_entry_type, - effective_hold_minutes, -) - - -class TestHubTradesReviewFields(unittest.TestCase): - def test_display_entry_type_for_manual_monitor_review(self): - d = { - "monitor_type": "下单监控", - "entry_reason": "", - "reviewed_entry_reason": "突破回踩", - "reviewed_at": "2026-06-08 10:00:00", - } - self.assertEqual(display_entry_type_label(d), "突破回踩") - - def test_effective_entry_type_prefers_reviewed(self): - d = { - "entry_reason": "突破回踩", - "reviewed_entry_reason": "趋势回调", - "monitor_type": "下单监控", - } - self.assertEqual(effective_entry_type(d), "趋势回调") - - def test_effective_hold_minutes_prefers_reviewed(self): - d = { - "hold_minutes": 30, - "reviewed_hold_minutes": 95, - "opened_at_ms": 1_700_000_000_000, - "closed_at_ms": 1_700_001_800_000, - } - self.assertEqual(effective_hold_minutes(d), 95) - - def test_normalize_archive_trade_row_review_fields(self): - closed = (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d %H:%M:%S") - opened = (datetime.now() - timedelta(days=2, hours=2)).strftime("%Y-%m-%d %H:%M:%S") - row = _normalize_archive_trade_row( - { - "id": 9, - "symbol": "ONDO/USDT", - "direction": "short", - "result": "止损", - "reviewed_result": "手动平仓", - "pnl_amount": -2.5, - "reviewed_pnl_amount": -2.58, - "opened_at": opened, - "reviewed_opened_at": "2026-06-07 14:30:00", - "closed_at": closed, - "reviewed_closed_at": "2026-06-08 08:44:21", - "opened_at_ms": 1_700_000_000_000, - "closed_at_ms": 1_700_007_200_000, - "entry_reason": "突破回踩", - "reviewed_entry_reason": "趋势回调", - "hold_minutes": 30, - "reviewed_hold_minutes": 1080, - "monitor_type": "趋势回调", - "reviewed_at": closed, - }, - exchange_key="gate", - ) - self.assertIsNotNone(row) - assert row is not None - self.assertEqual(row["entry_type"], "趋势回调") - self.assertEqual(row["hold_minutes"], 1080) - self.assertEqual(row["opened_at"], "2026-06-07 14:30:00") - self.assertEqual(row["closed_at"], "2026-06-08 08:44:21") - self.assertTrue(row["reviewed"]) - - def test_archive_cache_enriches_review_display_fields(self): - with tempfile.TemporaryDirectory() as td: - db = Path(td) / "archive.db" - init_db(db) - upsert_trades_cache( - "gate", - [ - { - "id": 3, - "symbol": "ONDO/USDT", - "direction": "short", - "result": "手动平仓", - "pnl_amount": -2.58, - "opened_at": "2026-06-07 14:30:00", - "closed_at": "2026-06-08 08:44:21", - "opened_at_ms": 1_781_000_000_000, - "closed_at_ms": 1_781_065_000_000, - "entry_type": "趋势回调", - "hold_minutes": 1080, - "hold_minutes_text": "18小时0分钟", - "reviewed": True, - } - ], - db_path=db, - ) - rows = load_symbol_trades("gate", "ONDO/USDT", db_path=db) - self.assertEqual(len(rows), 1) - self.assertEqual(rows[0]["entry_type"], "趋势回调") - self.assertEqual(rows[0]["hold_minutes"], 1080) - self.assertTrue(rows[0]["opened_at"].startswith("2026-06-07")) - self.assertTrue(rows[0]["closed_at"].startswith("2026-06-08")) - - -if __name__ == "__main__": - unittest.main() +"""档案交易:复盘字段优先(开仓类型、持仓时长、开平仓时间)。""" + +from __future__ import annotations + +import tempfile +import unittest +from datetime import datetime, timedelta +from pathlib import Path + +from lib.hub.hub_symbol_archive_lib import init_db, load_symbol_trades, upsert_trades_cache +from lib.hub.hub_trades_lib import ( + _normalize_archive_trade_row, + display_entry_type_label, + effective_entry_type, + effective_hold_minutes, +) + + +class TestHubTradesReviewFields(unittest.TestCase): + def test_display_entry_type_for_manual_monitor_review(self): + d = { + "monitor_type": "下单监控", + "entry_reason": "", + "reviewed_entry_reason": "突破回踩", + "reviewed_at": "2026-06-08 10:00:00", + } + self.assertEqual(display_entry_type_label(d), "突破回踩") + + def test_effective_entry_type_prefers_reviewed(self): + d = { + "entry_reason": "突破回踩", + "reviewed_entry_reason": "趋势回调", + "monitor_type": "下单监控", + } + self.assertEqual(effective_entry_type(d), "趋势回调") + + def test_effective_hold_minutes_prefers_reviewed(self): + d = { + "hold_minutes": 30, + "reviewed_hold_minutes": 95, + "opened_at_ms": 1_700_000_000_000, + "closed_at_ms": 1_700_001_800_000, + } + self.assertEqual(effective_hold_minutes(d), 95) + + def test_normalize_archive_trade_row_review_fields(self): + closed = (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d %H:%M:%S") + opened = (datetime.now() - timedelta(days=2, hours=2)).strftime("%Y-%m-%d %H:%M:%S") + row = _normalize_archive_trade_row( + { + "id": 9, + "symbol": "ONDO/USDT", + "direction": "short", + "result": "止损", + "reviewed_result": "手动平仓", + "pnl_amount": -2.5, + "reviewed_pnl_amount": -2.58, + "opened_at": opened, + "reviewed_opened_at": "2026-06-07 14:30:00", + "closed_at": closed, + "reviewed_closed_at": "2026-06-08 08:44:21", + "opened_at_ms": 1_700_000_000_000, + "closed_at_ms": 1_700_007_200_000, + "entry_reason": "突破回踩", + "reviewed_entry_reason": "趋势回调", + "hold_minutes": 30, + "reviewed_hold_minutes": 1080, + "monitor_type": "趋势回调", + "reviewed_at": closed, + }, + exchange_key="gate", + ) + self.assertIsNotNone(row) + assert row is not None + self.assertEqual(row["entry_type"], "趋势回调") + self.assertEqual(row["hold_minutes"], 1080) + self.assertEqual(row["opened_at"], "2026-06-07 14:30:00") + self.assertEqual(row["closed_at"], "2026-06-08 08:44:21") + self.assertTrue(row["reviewed"]) + + def test_archive_cache_enriches_review_display_fields(self): + with tempfile.TemporaryDirectory() as td: + db = Path(td) / "archive.db" + init_db(db) + upsert_trades_cache( + "gate", + [ + { + "id": 3, + "symbol": "ONDO/USDT", + "direction": "short", + "result": "手动平仓", + "pnl_amount": -2.58, + "opened_at": "2026-06-07 14:30:00", + "closed_at": "2026-06-08 08:44:21", + "opened_at_ms": 1_781_000_000_000, + "closed_at_ms": 1_781_065_000_000, + "entry_type": "趋势回调", + "hold_minutes": 1080, + "hold_minutes_text": "18小时0分钟", + "reviewed": True, + } + ], + db_path=db, + ) + rows = load_symbol_trades("gate", "ONDO/USDT", db_path=db) + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["entry_type"], "趋势回调") + self.assertEqual(rows[0]["hold_minutes"], 1080) + self.assertTrue(rows[0]["opened_at"].startswith("2026-06-07")) + self.assertTrue(rows[0]["closed_at"].startswith("2026-06-08")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hub_volume_rank_lib.py b/tests/test_hub_volume_rank_lib.py index c0a2fb6..f33b0a5 100644 --- a/tests/test_hub_volume_rank_lib.py +++ b/tests/test_hub_volume_rank_lib.py @@ -1,184 +1,184 @@ -from datetime import datetime -from unittest.mock import MagicMock - -from hub_volume_rank_lib import ( - CACHE_VERSION, - LIQUIDITY_RANK_CACHE_VERSION, - TOP_N_DEFAULT, - _exchange_rank_row_stale, - _okx_turnover_usdt, - _scores_from_binance, - _scores_from_gate, - build_usdt_swap_volume_ranks, - cache_needs_refresh, - format_volume_quote, - merge_exchange_rank, - rank_date_label, - resolve_daily_volume_rank, -) - - -def test_rank_date_label_after_reset(): - # 2026-06-08 09:00 北京时间 → 昨日交易日 2026-06-07 - dt = datetime(2026, 6, 8, 9, 0, 0) - assert rank_date_label(now=dt, reset_hour=8) == "2026-06-07" - - -def test_rank_date_label_before_reset(): - # 2026-06-08 07:00 → 当前交易日仍算 2026-06-07,昨日为 2026-06-06 - dt = datetime(2026, 6, 8, 7, 0, 0) - assert rank_date_label(now=dt, reset_hour=8) == "2026-06-06" - - -def test_format_volume_quote(): - assert format_volume_quote(1_500_000_000) == "1.50B" - assert format_volume_quote(2_300_000) == "2.30M" - assert format_volume_quote(4500) == "4.50K" - - -def test_okx_turnover_usdt(): - qv = _okx_turnover_usdt({"volCcy24h": "100", "last": "50"}) - assert qv == 5000.0 - - -def test_cache_needs_refresh_and_merge(): - cache = {"rank_date": "2026-06-05", "exchanges": {}} - assert cache_needs_refresh(cache, expected_rank_date="2026-06-07") is True - merged = merge_exchange_rank( - cache, - "binance", - { - "ok": True, - "rank_date": "2026-06-07", - "items": [{"rank": 1, "symbol": "BTC/USDT", "volume_quote": 1.0}], - "total_symbols": 100, - }, - ) - assert merged["exchanges"]["binance"]["items"][0]["symbol"] == "BTC/USDT" - assert merged["rank_date"] == "2026-06-07" - - -def test_stale_cache_version_forces_refresh(): - cache = {"version": CACHE_VERSION - 1, "rank_date": "2026-06-07", "exchanges": {"okx": {"items": [{}]}}} - assert cache_needs_refresh(cache) is True - - -def test_short_item_list_is_stale(): - items = [{"rank": i, "symbol": f"S{i}/USDT"} for i in range(1, 13)] - row = {"items": items, "total_symbols": 12} - assert _exchange_rank_row_stale(row) is True - full = {"items": items + [{"rank": i, "symbol": f"X{i}/USDT"} for i in range(13, TOP_N_DEFAULT + 1)], "total_symbols": 300} - assert _exchange_rank_row_stale(full) is False - - -def test_scores_from_binance_uses_fapi_lightweight_api(): - ex = MagicMock() - ex.id = "binance" - ex.fapiPublicGetTicker24hr.return_value = [ - {"symbol": "BTCUSDT", "quoteVolume": "9000000"}, - {"symbol": "ETHUSDT", "quoteVolume": "5000000"}, - ] - scored = _scores_from_binance(ex) - assert scored[0][1] == "BTC" - assert scored[0][2] == 9000000.0 - ex.fetch_tickers.assert_not_called() - - -def test_scores_from_binance_skips_fetch_tickers_on_api_error(): - ex = MagicMock() - ex.id = "binance" - ex.fapiPublicGetTicker24hr.side_effect = RuntimeError("network") - scored = _scores_from_binance(ex) - assert scored == [] - ex.fetch_tickers.assert_not_called() - - -def test_scores_from_gate_uses_futures_tickers_api(): - ex = MagicMock() - ex.id = "gateio" - ex.publicFuturesGetSettleTickers.return_value = [ - {"contract": "BTC_USDT", "volume_24h_quote": "8000000"}, - {"contract": "ETH_USDT", "volume_24h_quote": "4000000"}, - ] - scored = _scores_from_gate(ex) - assert scored[0][1] == "BTC" - ex.fetch_tickers.assert_not_called() - - -def test_scores_from_gate_skips_fetch_tickers_on_api_error(): - ex = MagicMock() - ex.id = "gateio" - ex.publicFuturesGetSettleTickers.side_effect = RuntimeError("network") - scored = _scores_from_gate(ex) - assert scored == [] - ex.fetch_tickers.assert_not_called() - - -def test_resolve_daily_volume_rank_caches_result(): - cache = {"version": 0, "updated_at": 0.0, "ranks": {}, "total": 0} - ex = MagicMock() - ex.id = "binance" - ex.fapiPublicGetTicker24hr.return_value = [ - {"symbol": "BTCUSDT", "quoteVolume": "100"}, - {"symbol": "ETHUSDT", "quoteVolume": "50"}, - ] - - rank, total = resolve_daily_volume_rank( - "BTC", - cache, - now_ts=1000.0, - ttl_sec=60.0, - exchange=ex, - ensure_markets_loaded=lambda: None, - ) - assert rank == 1 - assert total == 2 - assert cache["version"] == LIQUIDITY_RANK_CACHE_VERSION - calls = ex.fapiPublicGetTicker24hr.call_count - - rank2, _ = resolve_daily_volume_rank( - "BTC", - cache, - now_ts=1010.0, - ttl_sec=60.0, - exchange=ex, - ensure_markets_loaded=lambda: None, - ) - assert rank2 == 1 - assert ex.fapiPublicGetTicker24hr.call_count == calls - - -def test_resolve_daily_volume_rank_keeps_stale_cache_when_refresh_empty(): - cache = { - "version": LIQUIDITY_RANK_CACHE_VERSION, - "updated_at": 900.0, - "ranks": {"BTC": 1}, - "total": 100, - } - ex = MagicMock() - ex.id = "binance" - ex.fapiPublicGetTicker24hr.return_value = [] - - rank, total = resolve_daily_volume_rank( - "BTC", - cache, - now_ts=2000.0, - ttl_sec=60.0, - exchange=ex, - ensure_markets_loaded=lambda: None, - ) - assert rank == 1 - assert total == 100 - assert cache["updated_at"] == 900.0 - ex.fetch_tickers.assert_not_called() - - -def test_build_usdt_swap_volume_ranks(): - ex = MagicMock() - ex.id = "binance" - ex.fapiPublicGetTicker24hr.return_value = [ - {"symbol": "SOLUSDT", "quoteVolume": "200"}, - ] - ranks, total = build_usdt_swap_volume_ranks(ex, lambda: None) - assert ranks["SOL"] == 1 - assert total == 1 +from datetime import datetime +from unittest.mock import MagicMock + +from lib.hub.hub_volume_rank_lib import ( + CACHE_VERSION, + LIQUIDITY_RANK_CACHE_VERSION, + TOP_N_DEFAULT, + _exchange_rank_row_stale, + _okx_turnover_usdt, + _scores_from_binance, + _scores_from_gate, + build_usdt_swap_volume_ranks, + cache_needs_refresh, + format_volume_quote, + merge_exchange_rank, + rank_date_label, + resolve_daily_volume_rank, +) + + +def test_rank_date_label_after_reset(): + # 2026-06-08 09:00 北京时间 → 昨日交易日 2026-06-07 + dt = datetime(2026, 6, 8, 9, 0, 0) + assert rank_date_label(now=dt, reset_hour=8) == "2026-06-07" + + +def test_rank_date_label_before_reset(): + # 2026-06-08 07:00 → 当前交易日仍算 2026-06-07,昨日为 2026-06-06 + dt = datetime(2026, 6, 8, 7, 0, 0) + assert rank_date_label(now=dt, reset_hour=8) == "2026-06-06" + + +def test_format_volume_quote(): + assert format_volume_quote(1_500_000_000) == "1.50B" + assert format_volume_quote(2_300_000) == "2.30M" + assert format_volume_quote(4500) == "4.50K" + + +def test_okx_turnover_usdt(): + qv = _okx_turnover_usdt({"volCcy24h": "100", "last": "50"}) + assert qv == 5000.0 + + +def test_cache_needs_refresh_and_merge(): + cache = {"rank_date": "2026-06-05", "exchanges": {}} + assert cache_needs_refresh(cache, expected_rank_date="2026-06-07") is True + merged = merge_exchange_rank( + cache, + "binance", + { + "ok": True, + "rank_date": "2026-06-07", + "items": [{"rank": 1, "symbol": "BTC/USDT", "volume_quote": 1.0}], + "total_symbols": 100, + }, + ) + assert merged["exchanges"]["binance"]["items"][0]["symbol"] == "BTC/USDT" + assert merged["rank_date"] == "2026-06-07" + + +def test_stale_cache_version_forces_refresh(): + cache = {"version": CACHE_VERSION - 1, "rank_date": "2026-06-07", "exchanges": {"okx": {"items": [{}]}}} + assert cache_needs_refresh(cache) is True + + +def test_short_item_list_is_stale(): + items = [{"rank": i, "symbol": f"S{i}/USDT"} for i in range(1, 13)] + row = {"items": items, "total_symbols": 12} + assert _exchange_rank_row_stale(row) is True + full = {"items": items + [{"rank": i, "symbol": f"X{i}/USDT"} for i in range(13, TOP_N_DEFAULT + 1)], "total_symbols": 300} + assert _exchange_rank_row_stale(full) is False + + +def test_scores_from_binance_uses_fapi_lightweight_api(): + ex = MagicMock() + ex.id = "binance" + ex.fapiPublicGetTicker24hr.return_value = [ + {"symbol": "BTCUSDT", "quoteVolume": "9000000"}, + {"symbol": "ETHUSDT", "quoteVolume": "5000000"}, + ] + scored = _scores_from_binance(ex) + assert scored[0][1] == "BTC" + assert scored[0][2] == 9000000.0 + ex.fetch_tickers.assert_not_called() + + +def test_scores_from_binance_skips_fetch_tickers_on_api_error(): + ex = MagicMock() + ex.id = "binance" + ex.fapiPublicGetTicker24hr.side_effect = RuntimeError("network") + scored = _scores_from_binance(ex) + assert scored == [] + ex.fetch_tickers.assert_not_called() + + +def test_scores_from_gate_uses_futures_tickers_api(): + ex = MagicMock() + ex.id = "gateio" + ex.publicFuturesGetSettleTickers.return_value = [ + {"contract": "BTC_USDT", "volume_24h_quote": "8000000"}, + {"contract": "ETH_USDT", "volume_24h_quote": "4000000"}, + ] + scored = _scores_from_gate(ex) + assert scored[0][1] == "BTC" + ex.fetch_tickers.assert_not_called() + + +def test_scores_from_gate_skips_fetch_tickers_on_api_error(): + ex = MagicMock() + ex.id = "gateio" + ex.publicFuturesGetSettleTickers.side_effect = RuntimeError("network") + scored = _scores_from_gate(ex) + assert scored == [] + ex.fetch_tickers.assert_not_called() + + +def test_resolve_daily_volume_rank_caches_result(): + cache = {"version": 0, "updated_at": 0.0, "ranks": {}, "total": 0} + ex = MagicMock() + ex.id = "binance" + ex.fapiPublicGetTicker24hr.return_value = [ + {"symbol": "BTCUSDT", "quoteVolume": "100"}, + {"symbol": "ETHUSDT", "quoteVolume": "50"}, + ] + + rank, total = resolve_daily_volume_rank( + "BTC", + cache, + now_ts=1000.0, + ttl_sec=60.0, + exchange=ex, + ensure_markets_loaded=lambda: None, + ) + assert rank == 1 + assert total == 2 + assert cache["version"] == LIQUIDITY_RANK_CACHE_VERSION + calls = ex.fapiPublicGetTicker24hr.call_count + + rank2, _ = resolve_daily_volume_rank( + "BTC", + cache, + now_ts=1010.0, + ttl_sec=60.0, + exchange=ex, + ensure_markets_loaded=lambda: None, + ) + assert rank2 == 1 + assert ex.fapiPublicGetTicker24hr.call_count == calls + + +def test_resolve_daily_volume_rank_keeps_stale_cache_when_refresh_empty(): + cache = { + "version": LIQUIDITY_RANK_CACHE_VERSION, + "updated_at": 900.0, + "ranks": {"BTC": 1}, + "total": 100, + } + ex = MagicMock() + ex.id = "binance" + ex.fapiPublicGetTicker24hr.return_value = [] + + rank, total = resolve_daily_volume_rank( + "BTC", + cache, + now_ts=2000.0, + ttl_sec=60.0, + exchange=ex, + ensure_markets_loaded=lambda: None, + ) + assert rank == 1 + assert total == 100 + assert cache["updated_at"] == 900.0 + ex.fetch_tickers.assert_not_called() + + +def test_build_usdt_swap_volume_ranks(): + ex = MagicMock() + ex.id = "binance" + ex.fapiPublicGetTicker24hr.return_value = [ + {"symbol": "SOLUSDT", "quoteVolume": "200"}, + ] + ranks, total = build_usdt_swap_volume_ranks(ex, lambda: None) + assert ranks["SOL"] == 1 + assert total == 1 diff --git a/tests/test_instance_embed_context_lib.py b/tests/test_instance_embed_context_lib.py index 79610b6..22683d9 100644 --- a/tests/test_instance_embed_context_lib.py +++ b/tests/test_instance_embed_context_lib.py @@ -1,28 +1,28 @@ -from instance_embed_context_lib import embed_render_plan, trade_records_summary - - -def test_embed_fragment_trade_is_light(): - plan = embed_render_plan("trade", "fragment") - assert plan.exchange_capitals is False - assert plan.records_rows is False - assert plan.records_summary is False - assert plan.orders is True - assert plan.key_history is False - - -def test_embed_shell_trade_summary_only(): - plan = embed_render_plan("trade", "shell") - assert plan.exchange_capitals is True - assert plan.records_summary is True - assert plan.records_rows is False - - -def test_embed_records_page_loads_rows(): - plan = embed_render_plan("records", "fragment") - assert plan.records_rows is True - - -def test_full_page_unchanged(): - plan = embed_render_plan("trade", None) - assert plan.records_rows is True - assert plan.exchange_capitals is True +from lib.instance.instance_embed_context_lib import embed_render_plan, trade_records_summary + + +def test_embed_fragment_trade_is_light(): + plan = embed_render_plan("trade", "fragment") + assert plan.exchange_capitals is False + assert plan.records_rows is False + assert plan.records_summary is False + assert plan.orders is True + assert plan.key_history is False + + +def test_embed_shell_trade_summary_only(): + plan = embed_render_plan("trade", "shell") + assert plan.exchange_capitals is True + assert plan.records_summary is True + assert plan.records_rows is False + + +def test_embed_records_page_loads_rows(): + plan = embed_render_plan("records", "fragment") + assert plan.records_rows is True + + +def test_full_page_unchanged(): + plan = embed_render_plan("trade", None) + assert plan.records_rows is True + assert plan.exchange_capitals is True diff --git a/tests/test_instance_embed_lib.py b/tests/test_instance_embed_lib.py index 72700cc..fc20121 100644 --- a/tests/test_instance_embed_lib.py +++ b/tests/test_instance_embed_lib.py @@ -1,26 +1,26 @@ -from instance_embed_lib import ( - EMBED_TABS, - path_to_embed_tab, - rewrite_embed_dest, -) - - -def test_path_to_embed_tab(): - assert path_to_embed_tab("/trade") == "trade" - assert path_to_embed_tab("/key_monitor") == "key_monitor" - assert path_to_embed_tab("/strategy/records") == "strategy_records" - assert path_to_embed_tab("/unknown") is None - - -def test_rewrite_embed_dest(): - url = rewrite_embed_dest("/trade", hub_theme="dark") - assert url.startswith("/embed?") - assert "tab=trade" in url - assert "embed=1" in url - assert "hub_theme=dark" in url - - -def test_embed_tabs_cover_main_nav(): - assert "trade" in EMBED_TABS - assert "key_monitor" in EMBED_TABS - assert "records" in EMBED_TABS +from lib.instance.instance_embed_lib import ( + EMBED_TABS, + path_to_embed_tab, + rewrite_embed_dest, +) + + +def test_path_to_embed_tab(): + assert path_to_embed_tab("/trade") == "trade" + assert path_to_embed_tab("/key_monitor") == "key_monitor" + assert path_to_embed_tab("/strategy/records") == "strategy_records" + assert path_to_embed_tab("/unknown") is None + + +def test_rewrite_embed_dest(): + url = rewrite_embed_dest("/trade", hub_theme="dark") + assert url.startswith("/embed?") + assert "tab=trade" in url + assert "embed=1" in url + assert "hub_theme=dark" in url + + +def test_embed_tabs_cover_main_nav(): + assert "trade" in EMBED_TABS + assert "key_monitor" in EMBED_TABS + assert "records" in EMBED_TABS diff --git a/tests/test_instance_nav_lib.py b/tests/test_instance_nav_lib.py index 52cfab8..1181ccc 100644 --- a/tests/test_instance_nav_lib.py +++ b/tests/test_instance_nav_lib.py @@ -1,21 +1,21 @@ -from instance_nav_lib import request_is_hub_soft_nav - - -def test_request_is_hub_soft_nav(): - class Req: - args = {"embed": "1"} - headers = {"X-Instance-Soft-Nav": "1"} - - assert request_is_hub_soft_nav(Req()) is True - - class Req2: - args = {"embed": "1"} - headers = {} - - assert request_is_hub_soft_nav(Req2()) is False - - class Req3: - args = {} - headers = {"X-Instance-Soft-Nav": "1"} - - assert request_is_hub_soft_nav(Req3()) is False +from lib.instance.instance_nav_lib import request_is_hub_soft_nav + + +def test_request_is_hub_soft_nav(): + class Req: + args = {"embed": "1"} + headers = {"X-Instance-Soft-Nav": "1"} + + assert request_is_hub_soft_nav(Req()) is True + + class Req2: + args = {"embed": "1"} + headers = {} + + assert request_is_hub_soft_nav(Req2()) is False + + class Req3: + args = {} + headers = {"X-Instance-Soft-Nav": "1"} + + assert request_is_hub_soft_nav(Req3()) is False diff --git a/tests/test_key_monitor_box_invalidate.py b/tests/test_key_monitor_box_invalidate.py index b190015..20302c4 100644 --- a/tests/test_key_monitor_box_invalidate.py +++ b/tests/test_key_monitor_box_invalidate.py @@ -1,34 +1,34 @@ -import unittest - -from key_monitor_lib import ( - BOX_BREAKOUT_CLOSE_OPPOSITE, - box_breakout_invalidate_by_mark, - box_breakout_invalidate_edge_label, -) - - -class BoxBreakoutInvalidateTests(unittest.TestCase): - def test_short_invalidates_above_upper(self): - self.assertTrue(box_breakout_invalidate_by_mark("short", 62.511, 61.746, 60.569)) - - def test_short_stays_valid_inside_or_below(self): - self.assertFalse(box_breakout_invalidate_by_mark("short", 61.0, 61.746, 60.569)) - self.assertFalse(box_breakout_invalidate_by_mark("short", 60.0, 61.746, 60.569)) - - def test_long_invalidates_below_lower(self): - self.assertTrue(box_breakout_invalidate_by_mark("long", 94.0, 100.0, 95.0)) - - def test_long_stays_valid_inside_or_above(self): - self.assertFalse(box_breakout_invalidate_by_mark("long", 98.0, 100.0, 95.0)) - self.assertFalse(box_breakout_invalidate_by_mark("long", 101.0, 100.0, 95.0)) - - def test_edge_label(self): - self.assertEqual(box_breakout_invalidate_edge_label("long"), "下沿") - self.assertEqual(box_breakout_invalidate_edge_label("short"), "上沿") - - def test_close_reason_constant(self): - self.assertEqual(BOX_BREAKOUT_CLOSE_OPPOSITE, "box_opposite_break") - - -if __name__ == "__main__": - unittest.main() +import unittest + +from lib.key_monitor.key_monitor_lib import ( + BOX_BREAKOUT_CLOSE_OPPOSITE, + box_breakout_invalidate_by_mark, + box_breakout_invalidate_edge_label, +) + + +class BoxBreakoutInvalidateTests(unittest.TestCase): + def test_short_invalidates_above_upper(self): + self.assertTrue(box_breakout_invalidate_by_mark("short", 62.511, 61.746, 60.569)) + + def test_short_stays_valid_inside_or_below(self): + self.assertFalse(box_breakout_invalidate_by_mark("short", 61.0, 61.746, 60.569)) + self.assertFalse(box_breakout_invalidate_by_mark("short", 60.0, 61.746, 60.569)) + + def test_long_invalidates_below_lower(self): + self.assertTrue(box_breakout_invalidate_by_mark("long", 94.0, 100.0, 95.0)) + + def test_long_stays_valid_inside_or_above(self): + self.assertFalse(box_breakout_invalidate_by_mark("long", 98.0, 100.0, 95.0)) + self.assertFalse(box_breakout_invalidate_by_mark("long", 101.0, 100.0, 95.0)) + + def test_edge_label(self): + self.assertEqual(box_breakout_invalidate_edge_label("long"), "下沿") + self.assertEqual(box_breakout_invalidate_edge_label("short"), "上沿") + + def test_close_reason_constant(self): + self.assertEqual(BOX_BREAKOUT_CLOSE_OPPOSITE, "box_opposite_break") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_key_monitor_rs_alert.py b/tests/test_key_monitor_rs_alert.py index d2ac828..a125e8e 100644 --- a/tests/test_key_monitor_rs_alert.py +++ b/tests/test_key_monitor_rs_alert.py @@ -1,86 +1,86 @@ -"""阻力/支撑提醒:占位与间隔防重复推送。""" -from __future__ import annotations - -import sqlite3 -import unittest -from datetime import datetime, timedelta - -from key_monitor_lib import ( - claim_rs_level_notify, - notify_interval_elapsed, - run_rs_level_alert_tick, -) - - -def _row(**kwargs): - base = { - "upper": 2.174, - "lower": 1.694, - "notification_count": 0, - "max_notify": 3, - "notify_interval_min": 5, - "direction": "watch", - "last_notified_at": None, - "last_rs_bar_ts": None, - } - base.update(kwargs) - return base - - -class TestRsLevelAlertClaim(unittest.TestCase): - def setUp(self): - self.conn = sqlite3.connect(":memory:") - self.conn.execute( - "CREATE TABLE key_monitors (" - "id INTEGER PRIMARY KEY, notification_count INTEGER DEFAULT 0, " - "direction TEXT, last_notified_at TEXT, last_rs_bar_ts INTEGER)" - ) - self.conn.execute( - "INSERT INTO key_monitors (id, notification_count, direction) VALUES (1, 0, 'watch')" - ) - self.conn.commit() - - def test_claim_advances_once_per_index(self): - ok1 = claim_rs_level_notify( - self.conn, 1, 1, "long", "2026-06-02 00:25:00", 1000, prior_count=0 - ) - self.conn.commit() - self.assertTrue(ok1) - ok_dup = claim_rs_level_notify( - self.conn, 1, 1, "long", "2026-06-02 00:25:03", 1000, prior_count=0 - ) - self.assertFalse(ok_dup) - ok2 = claim_rs_level_notify( - self.conn, 1, 2, "long", "2026-06-02 00:30:00", 1000, prior_count=1 - ) - self.conn.commit() - self.assertTrue(ok2) - row = self.conn.execute( - "SELECT notification_count FROM key_monitors WHERE id=1" - ).fetchone() - self.assertEqual(row[0], 2) - - def test_second_push_requires_interval(self): - now = datetime(2026, 6, 2, 0, 26, 0) - row = _row( - notification_count=1, - direction="long", - last_notified_at="2026-06-02 00:25:00", - ) - tick = run_rs_level_alert_tick(row, 2.18, 1000, now, default_max_notify=3, default_interval_min=5) - self.assertIsNone(tick) - later = datetime(2026, 6, 2, 0, 30, 1) - tick2 = run_rs_level_alert_tick( - row, 2.18, 1000, later, default_max_notify=3, default_interval_min=5 - ) - self.assertIsNotNone(tick2) - self.assertEqual(tick2["notify_index"], 2) - self.assertEqual(tick2["prior_count"], 1) - - def test_notify_interval_invalid_timestamp_does_not_spam(self): - now = datetime(2026, 6, 2, 1, 0, 0) - self.assertFalse(notify_interval_elapsed("not-a-date", 5, now)) - - -if __name__ == "__main__": - unittest.main() +"""阻力/支撑提醒:占位与间隔防重复推送。""" +from __future__ import annotations + +import sqlite3 +import unittest +from datetime import datetime, timedelta + +from lib.key_monitor.key_monitor_lib import ( + claim_rs_level_notify, + notify_interval_elapsed, + run_rs_level_alert_tick, +) + + +def _row(**kwargs): + base = { + "upper": 2.174, + "lower": 1.694, + "notification_count": 0, + "max_notify": 3, + "notify_interval_min": 5, + "direction": "watch", + "last_notified_at": None, + "last_rs_bar_ts": None, + } + base.update(kwargs) + return base + + +class TestRsLevelAlertClaim(unittest.TestCase): + def setUp(self): + self.conn = sqlite3.connect(":memory:") + self.conn.execute( + "CREATE TABLE key_monitors (" + "id INTEGER PRIMARY KEY, notification_count INTEGER DEFAULT 0, " + "direction TEXT, last_notified_at TEXT, last_rs_bar_ts INTEGER)" + ) + self.conn.execute( + "INSERT INTO key_monitors (id, notification_count, direction) VALUES (1, 0, 'watch')" + ) + self.conn.commit() + + def test_claim_advances_once_per_index(self): + ok1 = claim_rs_level_notify( + self.conn, 1, 1, "long", "2026-06-02 00:25:00", 1000, prior_count=0 + ) + self.conn.commit() + self.assertTrue(ok1) + ok_dup = claim_rs_level_notify( + self.conn, 1, 1, "long", "2026-06-02 00:25:03", 1000, prior_count=0 + ) + self.assertFalse(ok_dup) + ok2 = claim_rs_level_notify( + self.conn, 1, 2, "long", "2026-06-02 00:30:00", 1000, prior_count=1 + ) + self.conn.commit() + self.assertTrue(ok2) + row = self.conn.execute( + "SELECT notification_count FROM key_monitors WHERE id=1" + ).fetchone() + self.assertEqual(row[0], 2) + + def test_second_push_requires_interval(self): + now = datetime(2026, 6, 2, 0, 26, 0) + row = _row( + notification_count=1, + direction="long", + last_notified_at="2026-06-02 00:25:00", + ) + tick = run_rs_level_alert_tick(row, 2.18, 1000, now, default_max_notify=3, default_interval_min=5) + self.assertIsNone(tick) + later = datetime(2026, 6, 2, 0, 30, 1) + tick2 = run_rs_level_alert_tick( + row, 2.18, 1000, later, default_max_notify=3, default_interval_min=5 + ) + self.assertIsNotNone(tick2) + self.assertEqual(tick2["notify_index"], 2) + self.assertEqual(tick2["prior_count"], 1) + + def test_notify_interval_invalid_timestamp_does_not_spam(self): + now = datetime(2026, 6, 2, 1, 0, 0) + self.assertFalse(notify_interval_elapsed("not-a-date", 5, now)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_key_monitor_rs_type.py b/tests/test_key_monitor_rs_type.py index 29d486f..5c1ea65 100644 --- a/tests/test_key_monitor_rs_type.py +++ b/tests/test_key_monitor_rs_type.py @@ -1,27 +1,27 @@ -import unittest - -from key_monitor_lib import ( - KEY_MONITOR_RS_TYPE, - is_rs_key_monitor_type, - rs_monitor_type_for_storage, - rs_monitor_type_label, -) - - -class KeyMonitorRsTypeTests(unittest.TestCase): - def test_legacy_types_still_recognized(self): - self.assertTrue(is_rs_key_monitor_type("关键阻力位")) - self.assertTrue(is_rs_key_monitor_type("关键支撑位")) - - def test_storage_normalizes_to_unified_type(self): - self.assertEqual(rs_monitor_type_for_storage("关键阻力位"), KEY_MONITOR_RS_TYPE) - self.assertEqual(rs_monitor_type_for_storage("关键支撑位"), KEY_MONITOR_RS_TYPE) - self.assertEqual(rs_monitor_type_for_storage(KEY_MONITOR_RS_TYPE), KEY_MONITOR_RS_TYPE) - - def test_label_merges_legacy_display(self): - self.assertEqual(rs_monitor_type_label("关键阻力位"), KEY_MONITOR_RS_TYPE) - self.assertEqual(rs_monitor_type_label("箱体突破"), "箱体突破") - - -if __name__ == "__main__": - unittest.main() +import unittest + +from lib.key_monitor.key_monitor_lib import ( + KEY_MONITOR_RS_TYPE, + is_rs_key_monitor_type, + rs_monitor_type_for_storage, + rs_monitor_type_label, +) + + +class KeyMonitorRsTypeTests(unittest.TestCase): + def test_legacy_types_still_recognized(self): + self.assertTrue(is_rs_key_monitor_type("关键阻力位")) + self.assertTrue(is_rs_key_monitor_type("关键支撑位")) + + def test_storage_normalizes_to_unified_type(self): + self.assertEqual(rs_monitor_type_for_storage("关键阻力位"), KEY_MONITOR_RS_TYPE) + self.assertEqual(rs_monitor_type_for_storage("关键支撑位"), KEY_MONITOR_RS_TYPE) + self.assertEqual(rs_monitor_type_for_storage(KEY_MONITOR_RS_TYPE), KEY_MONITOR_RS_TYPE) + + def test_label_merges_legacy_display(self): + self.assertEqual(rs_monitor_type_label("关键阻力位"), KEY_MONITOR_RS_TYPE) + self.assertEqual(rs_monitor_type_label("箱体突破"), "箱体突破") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_manual_sltp_lib.py b/tests/test_manual_sltp_lib.py index e94a757..9b1eba4 100644 --- a/tests/test_manual_sltp_lib.py +++ b/tests/test_manual_sltp_lib.py @@ -1,32 +1,32 @@ -from manual_sltp_lib import ( - MANUAL_FIXED_RR_DEFAULT, - calc_tp_from_fixed_rr, - parse_fixed_rr, - resolve_open_sltp_prices, -) - - -def test_calc_tp_from_fixed_rr_long(): - tp = calc_tp_from_fixed_rr("long", 100.0, 95.0, 1.5) - assert tp == 107.5 - - -def test_calc_tp_from_fixed_rr_short(): - tp = calc_tp_from_fixed_rr("short", 100.0, 105.0, 1.5) - assert tp == 92.5 - - -def test_resolve_open_fixed_rr_mode(): - sl, tp = resolve_open_sltp_prices( - "long", - 100.0, - "fixed_rr", - {"sl": "95", "fixed_rr": "1.5"}, - ) - assert sl == 95.0 - assert tp == 107.5 - - -def test_parse_fixed_rr_default(): - assert parse_fixed_rr(None) == MANUAL_FIXED_RR_DEFAULT - assert parse_fixed_rr("2") == 2.0 +from lib.trade.manual_sltp_lib import ( + MANUAL_FIXED_RR_DEFAULT, + calc_tp_from_fixed_rr, + parse_fixed_rr, + resolve_open_sltp_prices, +) + + +def test_calc_tp_from_fixed_rr_long(): + tp = calc_tp_from_fixed_rr("long", 100.0, 95.0, 1.5) + assert tp == 107.5 + + +def test_calc_tp_from_fixed_rr_short(): + tp = calc_tp_from_fixed_rr("short", 100.0, 105.0, 1.5) + assert tp == 92.5 + + +def test_resolve_open_fixed_rr_mode(): + sl, tp = resolve_open_sltp_prices( + "long", + 100.0, + "fixed_rr", + {"sl": "95", "fixed_rr": "1.5"}, + ) + assert sl == 95.0 + assert tp == 107.5 + + +def test_parse_fixed_rr_default(): + assert parse_fixed_rr(None) == MANUAL_FIXED_RR_DEFAULT + assert parse_fixed_rr("2") == 2.0 diff --git a/tests/test_order_monitor_display_lib.py b/tests/test_order_monitor_display_lib.py index cceb081..67c828d 100644 --- a/tests/test_order_monitor_display_lib.py +++ b/tests/test_order_monitor_display_lib.py @@ -1,102 +1,102 @@ -from order_monitor_display_lib import ( - apply_order_price_display_fields, - is_sl_breakeven_secured, - monitor_open_stop_loss, - order_monitor_tpsl_needs_sync, - resolve_live_tpsl_prices, - sl_breakeven_from_exchange_tpsl, - snapshot_rr, - snapshot_stop_loss, -) - - -def _calc_rr(direction, entry, sl, tp): - if direction == "long": - risk = entry - sl - reward = tp - entry - else: - risk = sl - entry - reward = entry - tp - if risk <= 0 or reward <= 0: - return None - return round(reward / risk, 4) - - -def test_snapshot_stop_loss_prefers_initial(): - assert snapshot_stop_loss(2.45, 2.6) == 2.45 - assert snapshot_stop_loss(None, 2.6) == 2.6 - - -def test_monitor_open_stop_loss_prefers_initial_snapshot(): - row = {"initial_stop_loss": 64000, "stop_loss": 63200} - assert monitor_open_stop_loss(row) == 64000 - - -def test_snapshot_rr_ignores_current_stop_after_manual_move(): - rr = snapshot_rr(_calc_rr, "long", 2.726, 2.45, 2.65, 3.3) - assert rr is not None - assert rr > 2.0 - - -def test_breakeven_long(): - assert is_sl_breakeven_secured("long", 2.726, 2.726) is True - assert is_sl_breakeven_secured("long", 2.726, 2.75) is True - assert is_sl_breakeven_secured("long", 2.726, 2.45) is False - - -def test_breakeven_short(): - assert is_sl_breakeven_secured("short", 72.73, 72.73) is True - assert is_sl_breakeven_secured("short", 72.73, 72.0) is True - assert is_sl_breakeven_secured("short", 72.73, 74.0) is False - - -def test_sl_breakeven_from_exchange_tpsl(): - ok = sl_breakeven_from_exchange_tpsl( - "long", - 2.726, - {"sl": {"trigger_price": 2.735}, "tp": {"trigger_price": 3.3}}, - ) - assert ok is True - - -def test_resolve_live_tpsl_prefers_exchange(): - disp_sl, disp_tp, ex_sl, ex_tp = resolve_live_tpsl_prices( - 1674, - 1647.65, - {"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}}, - ) - assert disp_sl == 1661 - assert disp_tp == 1647.65 - assert ex_sl == 1661 - assert ex_tp == 1647.65 - - -def test_order_monitor_tpsl_needs_sync_detects_sl_change(): - new_sl, new_tp, changed = order_monitor_tpsl_needs_sync( - 1674, - 1647.65, - {"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}}, - ) - assert changed is True - assert new_sl == 1661 - assert new_tp == 1647.65 - - -def test_apply_order_price_display_fields_live_sl(): - payload = {} - apply_order_price_display_fields( - payload, - direction="short", - entry_price=1663.45, - initial_stop_loss=1674, - stop_loss=1674, - take_profit=1647.65, - calc_rr_ratio_fn=_calc_rr, - exchange_tpsl={"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}}, - format_price_fn=lambda _s, v: f"{v:.2f}", - symbol="ETH/USDT:USDT", - ) - assert payload["stop_loss"] == 1661 - assert payload["stop_loss_display"] == "1661.00" - assert payload["sl_breakeven_secured"] is True - assert payload["rr_ratio"] is not None +from lib.trade.order_monitor_display_lib import ( + apply_order_price_display_fields, + is_sl_breakeven_secured, + monitor_open_stop_loss, + order_monitor_tpsl_needs_sync, + resolve_live_tpsl_prices, + sl_breakeven_from_exchange_tpsl, + snapshot_rr, + snapshot_stop_loss, +) + + +def _calc_rr(direction, entry, sl, tp): + if direction == "long": + risk = entry - sl + reward = tp - entry + else: + risk = sl - entry + reward = entry - tp + if risk <= 0 or reward <= 0: + return None + return round(reward / risk, 4) + + +def test_snapshot_stop_loss_prefers_initial(): + assert snapshot_stop_loss(2.45, 2.6) == 2.45 + assert snapshot_stop_loss(None, 2.6) == 2.6 + + +def test_monitor_open_stop_loss_prefers_initial_snapshot(): + row = {"initial_stop_loss": 64000, "stop_loss": 63200} + assert monitor_open_stop_loss(row) == 64000 + + +def test_snapshot_rr_ignores_current_stop_after_manual_move(): + rr = snapshot_rr(_calc_rr, "long", 2.726, 2.45, 2.65, 3.3) + assert rr is not None + assert rr > 2.0 + + +def test_breakeven_long(): + assert is_sl_breakeven_secured("long", 2.726, 2.726) is True + assert is_sl_breakeven_secured("long", 2.726, 2.75) is True + assert is_sl_breakeven_secured("long", 2.726, 2.45) is False + + +def test_breakeven_short(): + assert is_sl_breakeven_secured("short", 72.73, 72.73) is True + assert is_sl_breakeven_secured("short", 72.73, 72.0) is True + assert is_sl_breakeven_secured("short", 72.73, 74.0) is False + + +def test_sl_breakeven_from_exchange_tpsl(): + ok = sl_breakeven_from_exchange_tpsl( + "long", + 2.726, + {"sl": {"trigger_price": 2.735}, "tp": {"trigger_price": 3.3}}, + ) + assert ok is True + + +def test_resolve_live_tpsl_prefers_exchange(): + disp_sl, disp_tp, ex_sl, ex_tp = resolve_live_tpsl_prices( + 1674, + 1647.65, + {"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}}, + ) + assert disp_sl == 1661 + assert disp_tp == 1647.65 + assert ex_sl == 1661 + assert ex_tp == 1647.65 + + +def test_order_monitor_tpsl_needs_sync_detects_sl_change(): + new_sl, new_tp, changed = order_monitor_tpsl_needs_sync( + 1674, + 1647.65, + {"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}}, + ) + assert changed is True + assert new_sl == 1661 + assert new_tp == 1647.65 + + +def test_apply_order_price_display_fields_live_sl(): + payload = {} + apply_order_price_display_fields( + payload, + direction="short", + entry_price=1663.45, + initial_stop_loss=1674, + stop_loss=1674, + take_profit=1647.65, + calc_rr_ratio_fn=_calc_rr, + exchange_tpsl={"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}}, + format_price_fn=lambda _s, v: f"{v:.2f}", + symbol="ETH/USDT:USDT", + ) + assert payload["stop_loss"] == 1661 + assert payload["stop_loss_display"] == "1661.00" + assert payload["sl_breakeven_secured"] is True + assert payload["rr_ratio"] is not None diff --git a/tests/test_position_limit_count.py b/tests/test_position_limit_count.py index 8c68437..f3afd43 100644 --- a/tests/test_position_limit_count.py +++ b/tests/test_position_limit_count.py @@ -1,78 +1,78 @@ -import sqlite3 -import unittest - -from strategy_db import init_strategy_tables -from strategy_trade_labels import ( - MONITOR_TYPE_TREND_PULLBACK, - count_position_limit_active_monitors, -) - - -def _mem_conn(): - conn = sqlite3.connect(":memory:") - conn.row_factory = sqlite3.Row - conn.execute( - """CREATE TABLE order_monitors ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - symbol TEXT, - direction TEXT, - status TEXT, - monitor_type TEXT, - key_signal_type TEXT, - trend_plan_id INTEGER - )""" - ) - init_strategy_tables(conn) - return conn - - -class PositionLimitCountTests(unittest.TestCase): - def test_regular_monitor_counts(self): - conn = _mem_conn() - conn.execute( - "INSERT INTO order_monitors (symbol, status, monitor_type) VALUES ('ETH/USDT', 'active', '下单监控')" - ) - conn.commit() - self.assertEqual(count_position_limit_active_monitors(conn), 1) - - def test_trend_pullback_excluded(self): - conn = _mem_conn() - conn.execute( - """INSERT INTO order_monitors - (symbol, status, monitor_type, trend_plan_id) - VALUES ('ETH/USDT', 'active', ?, 12)""", - (MONITOR_TYPE_TREND_PULLBACK,), - ) - conn.commit() - self.assertEqual(count_position_limit_active_monitors(conn), 0) - - def test_active_roll_group_still_counts_regular_monitor(self): - conn = _mem_conn() - conn.execute( - "INSERT INTO order_monitors (id, symbol, status, monitor_type) VALUES (1, 'ETH/USDT', 'active', '下单监控')" - ) - conn.execute( - """INSERT INTO roll_groups - (order_monitor_id, symbol, direction, status) - VALUES (1, 'ETH/USDT', 'long', 'active')""" - ) - conn.commit() - self.assertEqual(count_position_limit_active_monitors(conn), 1) - - def test_mixed_monitors(self): - conn = _mem_conn() - conn.execute( - "INSERT INTO order_monitors (symbol, status, monitor_type) VALUES ('BTC/USDT', 'active', '下单监控')" - ) - conn.execute( - """INSERT INTO order_monitors - (symbol, status, monitor_type, trend_plan_id) - VALUES ('ETH/USDT', 'active', ?, 3)""", - (MONITOR_TYPE_TREND_PULLBACK,), - ) - conn.commit() - self.assertEqual(count_position_limit_active_monitors(conn), 1) - - -if __name__ == "__main__": - unittest.main() +import sqlite3 +import unittest + +from lib.strategy.strategy_db import init_strategy_tables +from lib.strategy.strategy_trade_labels import ( + MONITOR_TYPE_TREND_PULLBACK, + count_position_limit_active_monitors, +) + + +def _mem_conn(): + conn = sqlite3.connect(":memory:") + conn.row_factory = sqlite3.Row + conn.execute( + """CREATE TABLE order_monitors ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT, + direction TEXT, + status TEXT, + monitor_type TEXT, + key_signal_type TEXT, + trend_plan_id INTEGER + )""" + ) + init_strategy_tables(conn) + return conn + + +class PositionLimitCountTests(unittest.TestCase): + def test_regular_monitor_counts(self): + conn = _mem_conn() + conn.execute( + "INSERT INTO order_monitors (symbol, status, monitor_type) VALUES ('ETH/USDT', 'active', '下单监控')" + ) + conn.commit() + self.assertEqual(count_position_limit_active_monitors(conn), 1) + + def test_trend_pullback_excluded(self): + conn = _mem_conn() + conn.execute( + """INSERT INTO order_monitors + (symbol, status, monitor_type, trend_plan_id) + VALUES ('ETH/USDT', 'active', ?, 12)""", + (MONITOR_TYPE_TREND_PULLBACK,), + ) + conn.commit() + self.assertEqual(count_position_limit_active_monitors(conn), 0) + + def test_active_roll_group_still_counts_regular_monitor(self): + conn = _mem_conn() + conn.execute( + "INSERT INTO order_monitors (id, symbol, status, monitor_type) VALUES (1, 'ETH/USDT', 'active', '下单监控')" + ) + conn.execute( + """INSERT INTO roll_groups + (order_monitor_id, symbol, direction, status) + VALUES (1, 'ETH/USDT', 'long', 'active')""" + ) + conn.commit() + self.assertEqual(count_position_limit_active_monitors(conn), 1) + + def test_mixed_monitors(self): + conn = _mem_conn() + conn.execute( + "INSERT INTO order_monitors (symbol, status, monitor_type) VALUES ('BTC/USDT', 'active', '下单监控')" + ) + conn.execute( + """INSERT INTO order_monitors + (symbol, status, monitor_type, trend_plan_id) + VALUES ('ETH/USDT', 'active', ?, 3)""", + (MONITOR_TYPE_TREND_PULLBACK,), + ) + conn.commit() + self.assertEqual(count_position_limit_active_monitors(conn), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_position_sizing_risk_display.py b/tests/test_position_sizing_risk_display.py index 6b28eda..6022d40 100644 --- a/tests/test_position_sizing_risk_display.py +++ b/tests/test_position_sizing_risk_display.py @@ -1,34 +1,34 @@ -"""全仓 / 以损定仓 风险展示文案。""" -from __future__ import annotations - -import sys -import unittest -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(ROOT)) - -from position_sizing_lib import ( # noqa: E402 - format_risk_display_text, - risk_percent_for_storage, -) - - -class TestPositionSizingRiskDisplay(unittest.TestCase): - def test_full_margin_shows_amount_only(self): - self.assertEqual( - format_risk_display_text("full_margin", 1.0, 2.58, decimals=2), - "2.58U", - ) - self.assertIsNone(risk_percent_for_storage("full_margin", 1.0)) - - def test_risk_mode_shows_percent_and_amount(self): - self.assertEqual( - format_risk_display_text("risk", 2.0, 10.5, decimals=2), - "2%≈10.5U", - ) - self.assertEqual(risk_percent_for_storage("risk", 2.0), 2.0) - - -if __name__ == "__main__": - unittest.main() +"""全仓 / 以损定仓 风险展示文案。""" +from __future__ import annotations + +import sys +import unittest +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + +from lib.trade.position_sizing_lib import ( # noqa: E402 + format_risk_display_text, + risk_percent_for_storage, +) + + +class TestPositionSizingRiskDisplay(unittest.TestCase): + def test_full_margin_shows_amount_only(self): + self.assertEqual( + format_risk_display_text("full_margin", 1.0, 2.58, decimals=2), + "2.58U", + ) + self.assertIsNone(risk_percent_for_storage("full_margin", 1.0)) + + def test_risk_mode_shows_percent_and_amount(self): + self.assertEqual( + format_risk_display_text("risk", 2.0, 10.5, decimals=2), + "2%≈10.5U", + ) + self.assertEqual(risk_percent_for_storage("risk", 2.0), 2.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_strategy_roll_lib.py b/tests/test_strategy_roll_lib.py index 502c96d..db35285 100644 --- a/tests/test_strategy_roll_lib.py +++ b/tests/test_strategy_roll_lib.py @@ -1,112 +1,112 @@ -from strategy_roll_lib import ( - preview_roll, - roll_breakout_invalidate, - roll_breakout_trigger_crossed, - roll_fib_invalidate, - roll_fib_trigger_crossed, - solve_add_amount_for_total_risk, - validate_roll_geometry, -) - - -def test_solve_add_amount_long_one_risk(): - q2, err = solve_add_amount_for_total_risk( - "long", 1.0, 3000.0, 3100.0, 2950.0, 200.0, 1.0 - ) - assert err is None - avg = (1 * 3000 + q2 * 3100) / (1 + q2) - loss = (avg - 2950) * (1 + q2) - assert abs(loss - 200.0) < 0.01 - - -def test_preview_roll_market_short(): - preview, err = preview_roll( - direction="short", - symbol="HYPE/USDT", - qty_existing=3.0, - entry_existing=65.0, - initial_take_profit=60.0, - add_mode="market", - new_stop_loss=66.5, - risk_percent=2.0, - capital_base_usdt=1000.0, - add_price=64.0, - legs_done=1, - ) - assert err is None - assert preview["add_mode_label"] == "市价加仓" - sl = preview["new_stop_loss"] - avg = preview["avg_entry_after"] - qty = preview["qty_after"] - loss = (sl - avg) * qty - assert abs(loss - 20.0) < 0.01 - - -def test_fib_cross_long_down(): - assert roll_fib_trigger_crossed("long", 101.0, 100.0, 100.5) is True - assert roll_fib_trigger_crossed("long", 100.6, 100.6, 100.5) is False - - -def test_breakout_cross_long_up(): - assert roll_breakout_trigger_crossed("long", 99.0, 100.5, 100.0) is True - assert roll_breakout_invalidate("long", 98.0, 99.0) is True - assert roll_fib_invalidate("long", 110.0, 105.0, 95.0) is True - - -def test_preview_breakout_mode_label(): - preview, err = preview_roll( - direction="long", - symbol="ETH/USDT", - qty_existing=1.0, - entry_existing=3000.0, - initial_take_profit=3500.0, - add_mode="breakout", - new_stop_loss=2980.0, - breakthrough_price=3100.0, - risk_percent=10.0, - capital_base_usdt=1000.0, - add_price=3050.0, - ) - assert err is None - assert preview["add_mode_label"] == "突破加仓" - - -def test_breakout_geometry_short_mark_above_breakout(): - err = validate_roll_geometry( - "short", - "breakout", - new_stop_loss=568.0, - breakthrough_price=551.0, - entry_existing=560.0, - initial_take_profit=540.0, - mark_price=560.0, - ) - assert err is None - - -def test_breakout_geometry_short_rejects_mark_at_or_below_breakout(): - err = validate_roll_geometry( - "short", - "breakout", - new_stop_loss=568.0, - breakthrough_price=551.0, - entry_existing=560.0, - initial_take_profit=540.0, - mark_price=551.0, - ) - assert err is not None - assert "高于突破价" in err - - -def test_breakout_geometry_long_rejects_mark_at_or_above_breakout(): - err = validate_roll_geometry( - "long", - "breakout", - new_stop_loss=2980.0, - breakthrough_price=3100.0, - entry_existing=3000.0, - initial_take_profit=3500.0, - mark_price=3100.0, - ) - assert err is not None - assert "低于突破价" in err +from lib.strategy.strategy_roll_lib import ( + preview_roll, + roll_breakout_invalidate, + roll_breakout_trigger_crossed, + roll_fib_invalidate, + roll_fib_trigger_crossed, + solve_add_amount_for_total_risk, + validate_roll_geometry, +) + + +def test_solve_add_amount_long_one_risk(): + q2, err = solve_add_amount_for_total_risk( + "long", 1.0, 3000.0, 3100.0, 2950.0, 200.0, 1.0 + ) + assert err is None + avg = (1 * 3000 + q2 * 3100) / (1 + q2) + loss = (avg - 2950) * (1 + q2) + assert abs(loss - 200.0) < 0.01 + + +def test_preview_roll_market_short(): + preview, err = preview_roll( + direction="short", + symbol="HYPE/USDT", + qty_existing=3.0, + entry_existing=65.0, + initial_take_profit=60.0, + add_mode="market", + new_stop_loss=66.5, + risk_percent=2.0, + capital_base_usdt=1000.0, + add_price=64.0, + legs_done=1, + ) + assert err is None + assert preview["add_mode_label"] == "市价加仓" + sl = preview["new_stop_loss"] + avg = preview["avg_entry_after"] + qty = preview["qty_after"] + loss = (sl - avg) * qty + assert abs(loss - 20.0) < 0.01 + + +def test_fib_cross_long_down(): + assert roll_fib_trigger_crossed("long", 101.0, 100.0, 100.5) is True + assert roll_fib_trigger_crossed("long", 100.6, 100.6, 100.5) is False + + +def test_breakout_cross_long_up(): + assert roll_breakout_trigger_crossed("long", 99.0, 100.5, 100.0) is True + assert roll_breakout_invalidate("long", 98.0, 99.0) is True + assert roll_fib_invalidate("long", 110.0, 105.0, 95.0) is True + + +def test_preview_breakout_mode_label(): + preview, err = preview_roll( + direction="long", + symbol="ETH/USDT", + qty_existing=1.0, + entry_existing=3000.0, + initial_take_profit=3500.0, + add_mode="breakout", + new_stop_loss=2980.0, + breakthrough_price=3100.0, + risk_percent=10.0, + capital_base_usdt=1000.0, + add_price=3050.0, + ) + assert err is None + assert preview["add_mode_label"] == "突破加仓" + + +def test_breakout_geometry_short_mark_above_breakout(): + err = validate_roll_geometry( + "short", + "breakout", + new_stop_loss=568.0, + breakthrough_price=551.0, + entry_existing=560.0, + initial_take_profit=540.0, + mark_price=560.0, + ) + assert err is None + + +def test_breakout_geometry_short_rejects_mark_at_or_below_breakout(): + err = validate_roll_geometry( + "short", + "breakout", + new_stop_loss=568.0, + breakthrough_price=551.0, + entry_existing=560.0, + initial_take_profit=540.0, + mark_price=551.0, + ) + assert err is not None + assert "高于突破价" in err + + +def test_breakout_geometry_long_rejects_mark_at_or_above_breakout(): + err = validate_roll_geometry( + "long", + "breakout", + new_stop_loss=2980.0, + breakthrough_price=3100.0, + entry_existing=3000.0, + initial_take_profit=3500.0, + mark_price=3100.0, + ) + assert err is not None + assert "低于突破价" in err diff --git a/tests/test_strategy_roll_ui_lib.py b/tests/test_strategy_roll_ui_lib.py index 05b00e3..61b423d 100644 --- a/tests/test_strategy_roll_ui_lib.py +++ b/tests/test_strategy_roll_ui_lib.py @@ -1,44 +1,44 @@ -"""strategy_roll_ui_lib 单元测试。""" -from __future__ import annotations - -import sys -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(ROOT)) - -import strategy_roll_ui_lib as roll_ui - - -def test_compute_roll_chain_metrics_short(): - group = { - "id": 1, - "direction": "short", - "initial_take_profit": 60.0, - } - legs = [ - {"id": 10, "leg_index": 1, "amount": 3.0, "fill_price": 65.0, "status": "filled"}, - {"id": 11, "leg_index": 2, "amount": 5.0, "fill_price": 64.0, "status": "filled"}, - ] - per_leg, group_metrics = roll_ui.compute_roll_chain_metrics( - group, - legs, - qty_live=8.0, - entry_live=63.5, - monitor={"trigger_price": 66.0, "order_amount": 3.0}, - ) - assert per_leg[10]["avg_entry_after"] is not None - assert per_leg[11]["avg_entry_after"] is not None - assert group_metrics["reward_at_tp_usdt"] is not None - assert per_leg[11]["reward_at_tp_usdt"] >= per_leg[10]["reward_at_tp_usdt"] - - -def test_infer_initial_position_from_live(): - legs = [{"amount": 2.0, "fill_price": 64.0, "status": "filled"}] - q0, e0 = roll_ui.infer_initial_position(5.0, 63.0, legs) - assert q0 == 3.0 - assert abs(e0 - 62.3333333333) < 0.001 - - -def test_reward_at_tp_long(): - assert roll_ui.reward_at_tp_usdt("long", 100.0, 110.0, 2.0) == 20.0 +"""strategy_roll_ui_lib 单元测试。""" +from __future__ import annotations + +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + +import lib.strategy.strategy_roll_ui_lib as roll_ui + + +def test_compute_roll_chain_metrics_short(): + group = { + "id": 1, + "direction": "short", + "initial_take_profit": 60.0, + } + legs = [ + {"id": 10, "leg_index": 1, "amount": 3.0, "fill_price": 65.0, "status": "filled"}, + {"id": 11, "leg_index": 2, "amount": 5.0, "fill_price": 64.0, "status": "filled"}, + ] + per_leg, group_metrics = roll_ui.compute_roll_chain_metrics( + group, + legs, + qty_live=8.0, + entry_live=63.5, + monitor={"trigger_price": 66.0, "order_amount": 3.0}, + ) + assert per_leg[10]["avg_entry_after"] is not None + assert per_leg[11]["avg_entry_after"] is not None + assert group_metrics["reward_at_tp_usdt"] is not None + assert per_leg[11]["reward_at_tp_usdt"] >= per_leg[10]["reward_at_tp_usdt"] + + +def test_infer_initial_position_from_live(): + legs = [{"amount": 2.0, "fill_price": 64.0, "status": "filled"}] + q0, e0 = roll_ui.infer_initial_position(5.0, 63.0, legs) + assert q0 == 3.0 + assert abs(e0 - 62.3333333333) < 0.001 + + +def test_reward_at_tp_long(): + assert roll_ui.reward_at_tp_usdt("long", 100.0, 110.0, 2.0) == 20.0 diff --git a/tests/test_strategy_snapshot_dedup.py b/tests/test_strategy_snapshot_dedup.py index 6395159..ab32507 100644 --- a/tests/test_strategy_snapshot_dedup.py +++ b/tests/test_strategy_snapshot_dedup.py @@ -1,183 +1,183 @@ -"""策略快照:同一计划同结果不重复写入。""" -from __future__ import annotations - -import json -import sqlite3 -import sys -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(ROOT)) - -from strategy_snapshot_lib import ( # noqa: E402 - STRATEGY_TREND, - dedupe_strategy_snapshots, - init_strategy_snapshot_table, - list_strategy_snapshots, - save_trend_plan_snapshot, -) - - -def _mem_conn() -> sqlite3.Connection: - conn = sqlite3.connect(":memory:") - conn.row_factory = sqlite3.Row - init_strategy_snapshot_table(conn) - return conn - - -def test_save_trend_plan_snapshot_skips_duplicate_result(): - conn = _mem_conn() - plan = { - "id": 42, - "symbol": "ONDO/USDT", - "exchange_symbol": "ONDO/USDT:USDT", - "direction": "short", - "status": "active", - "opened_at": "2026-06-08 08:00:00", - "legs_done": 4, - "dca_legs": 4, - "first_order_done": 1, - "grid_prices_json": "[]", - "leg_amounts_json": "[]", - } - cfg = {"app_module": type("M", (), {"app_now_str": staticmethod(lambda: "2026-06-08 08:41:00")})()} - save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.3) - save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.4) - conn.commit() - rows = conn.execute( - "SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=? AND result_label=?", - (42, "止损"), - ).fetchone() - assert int(rows["c"]) == 1 - - -def test_dedupe_strategy_snapshots_handles_many_duplicates(): - conn = _mem_conn() - payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False) - for snap_id in range(1, 46): - conn.execute( - """INSERT INTO strategy_trade_snapshots ( - id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount - ) VALUES (?,?,?,?,?,?,?,?,?)""", - ( - snap_id, - STRATEGY_TREND, - 99, - "ONDO/USDT", - "止损", - payload, - "2026-06-08 08:41:00", - "2026-06-08 08:41:00", - -2.2, - ), - ) - conn.commit() - removed = dedupe_strategy_snapshots(conn) - conn.commit() - assert removed == 44 - row = conn.execute( - "SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=?", - (99,), - ).fetchone() - assert int(row["c"]) == 1 - - -def test_dedupe_strategy_snapshots_keeps_latest_id(): - conn = _mem_conn() - payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False) - for snap_id, pnl in ((1, -2.23), (2, -2.31), (3, -2.38)): - conn.execute( - """INSERT INTO strategy_trade_snapshots ( - id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount - ) VALUES (?,?,?,?,?,?,?,?,?)""", - ( - snap_id, - STRATEGY_TREND, - 5, - "ONDO/USDT", - "止损", - payload, - "2026-06-08 08:41:00", - "2026-06-08 08:41:00", - pnl, - ), - ) - conn.commit() - removed = dedupe_strategy_snapshots(conn) - conn.commit() - assert removed == 2 - row = conn.execute( - "SELECT id, pnl_amount FROM strategy_trade_snapshots WHERE source_id=?", - (5,), - ).fetchone() - assert int(row["id"]) == 3 - assert abs(float(row["pnl_amount"]) - (-2.38)) < 1e-6 - - -def test_list_strategy_snapshots_hides_duplicate_keys(): - conn = _mem_conn() - payload = json.dumps({"symbol": "ONDO/USDT", "dca_levels": []}, ensure_ascii=False) - for snap_id in (10, 11, 12): - conn.execute( - """INSERT INTO strategy_trade_snapshots ( - id, strategy_type, source_id, symbol, direction, result_label, - snapshot_json, closed_at, created_at, pnl_amount - ) VALUES (?,?,?,?,?,?,?,?,?,?)""", - ( - snap_id, - STRATEGY_TREND, - 7, - "ONDO/USDT", - "short", - "止损", - payload, - "2026-06-08 08:41:00", - "2026-06-08 08:41:00", - -2.2, - ), - ) - conn.commit() - rows = list_strategy_snapshots(conn, limit=50) - stop_rows = [r for r in rows if int(r.get("source_id") or 0) == 7] - assert len(stop_rows) == 1 - assert int(stop_rows[0]["id"]) == 12 - - -def test_dedupe_keeps_manual_over_stop_loss(): - conn = _mem_conn() - payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False) - for snap_id, label in ((10, "止损"), (11, "手动平仓")): - conn.execute( - """INSERT INTO strategy_trade_snapshots ( - id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount - ) VALUES (?,?,?,?,?,?,?,?,?)""", - ( - snap_id, - STRATEGY_TREND, - 7, - "ONDO/USDT", - label, - payload, - "2026-06-08 08:44:00", - "2026-06-08 08:44:00", - -2.23, - ), - ) - conn.commit() - removed = dedupe_strategy_snapshots(conn) - conn.commit() - assert removed == 1 - row = conn.execute( - "SELECT result_label FROM strategy_trade_snapshots WHERE source_id=?", - (7,), - ).fetchone() - assert row["result_label"] == "手动平仓" - - -if __name__ == "__main__": - test_save_trend_plan_snapshot_skips_duplicate_result() - test_dedupe_strategy_snapshots_handles_many_duplicates() - test_dedupe_strategy_snapshots_keeps_latest_id() - test_list_strategy_snapshots_hides_duplicate_keys() - test_dedupe_keeps_manual_over_stop_loss() - print("all ok") +"""策略快照:同一计划同结果不重复写入。""" +from __future__ import annotations + +import json +import sqlite3 +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + +from lib.strategy.strategy_snapshot_lib import ( # noqa: E402 + STRATEGY_TREND, + dedupe_strategy_snapshots, + init_strategy_snapshot_table, + list_strategy_snapshots, + save_trend_plan_snapshot, +) + + +def _mem_conn() -> sqlite3.Connection: + conn = sqlite3.connect(":memory:") + conn.row_factory = sqlite3.Row + init_strategy_snapshot_table(conn) + return conn + + +def test_save_trend_plan_snapshot_skips_duplicate_result(): + conn = _mem_conn() + plan = { + "id": 42, + "symbol": "ONDO/USDT", + "exchange_symbol": "ONDO/USDT:USDT", + "direction": "short", + "status": "active", + "opened_at": "2026-06-08 08:00:00", + "legs_done": 4, + "dca_legs": 4, + "first_order_done": 1, + "grid_prices_json": "[]", + "leg_amounts_json": "[]", + } + cfg = {"app_module": type("M", (), {"app_now_str": staticmethod(lambda: "2026-06-08 08:41:00")})()} + save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.3) + save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.4) + conn.commit() + rows = conn.execute( + "SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=? AND result_label=?", + (42, "止损"), + ).fetchone() + assert int(rows["c"]) == 1 + + +def test_dedupe_strategy_snapshots_handles_many_duplicates(): + conn = _mem_conn() + payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False) + for snap_id in range(1, 46): + conn.execute( + """INSERT INTO strategy_trade_snapshots ( + id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount + ) VALUES (?,?,?,?,?,?,?,?,?)""", + ( + snap_id, + STRATEGY_TREND, + 99, + "ONDO/USDT", + "止损", + payload, + "2026-06-08 08:41:00", + "2026-06-08 08:41:00", + -2.2, + ), + ) + conn.commit() + removed = dedupe_strategy_snapshots(conn) + conn.commit() + assert removed == 44 + row = conn.execute( + "SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=?", + (99,), + ).fetchone() + assert int(row["c"]) == 1 + + +def test_dedupe_strategy_snapshots_keeps_latest_id(): + conn = _mem_conn() + payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False) + for snap_id, pnl in ((1, -2.23), (2, -2.31), (3, -2.38)): + conn.execute( + """INSERT INTO strategy_trade_snapshots ( + id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount + ) VALUES (?,?,?,?,?,?,?,?,?)""", + ( + snap_id, + STRATEGY_TREND, + 5, + "ONDO/USDT", + "止损", + payload, + "2026-06-08 08:41:00", + "2026-06-08 08:41:00", + pnl, + ), + ) + conn.commit() + removed = dedupe_strategy_snapshots(conn) + conn.commit() + assert removed == 2 + row = conn.execute( + "SELECT id, pnl_amount FROM strategy_trade_snapshots WHERE source_id=?", + (5,), + ).fetchone() + assert int(row["id"]) == 3 + assert abs(float(row["pnl_amount"]) - (-2.38)) < 1e-6 + + +def test_list_strategy_snapshots_hides_duplicate_keys(): + conn = _mem_conn() + payload = json.dumps({"symbol": "ONDO/USDT", "dca_levels": []}, ensure_ascii=False) + for snap_id in (10, 11, 12): + conn.execute( + """INSERT INTO strategy_trade_snapshots ( + id, strategy_type, source_id, symbol, direction, result_label, + snapshot_json, closed_at, created_at, pnl_amount + ) VALUES (?,?,?,?,?,?,?,?,?,?)""", + ( + snap_id, + STRATEGY_TREND, + 7, + "ONDO/USDT", + "short", + "止损", + payload, + "2026-06-08 08:41:00", + "2026-06-08 08:41:00", + -2.2, + ), + ) + conn.commit() + rows = list_strategy_snapshots(conn, limit=50) + stop_rows = [r for r in rows if int(r.get("source_id") or 0) == 7] + assert len(stop_rows) == 1 + assert int(stop_rows[0]["id"]) == 12 + + +def test_dedupe_keeps_manual_over_stop_loss(): + conn = _mem_conn() + payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False) + for snap_id, label in ((10, "止损"), (11, "手动平仓")): + conn.execute( + """INSERT INTO strategy_trade_snapshots ( + id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount + ) VALUES (?,?,?,?,?,?,?,?,?)""", + ( + snap_id, + STRATEGY_TREND, + 7, + "ONDO/USDT", + label, + payload, + "2026-06-08 08:44:00", + "2026-06-08 08:44:00", + -2.23, + ), + ) + conn.commit() + removed = dedupe_strategy_snapshots(conn) + conn.commit() + assert removed == 1 + row = conn.execute( + "SELECT result_label FROM strategy_trade_snapshots WHERE source_id=?", + (7,), + ).fetchone() + assert row["result_label"] == "手动平仓" + + +if __name__ == "__main__": + test_save_trend_plan_snapshot_skips_duplicate_result() + test_dedupe_strategy_snapshots_handles_many_duplicates() + test_dedupe_strategy_snapshots_keeps_latest_id() + test_list_strategy_snapshots_hides_duplicate_keys() + test_dedupe_keeps_manual_over_stop_loss() + print("all ok") diff --git a/tests/test_trade_exchange_stats_lib.py b/tests/test_trade_exchange_stats_lib.py index c873a15..1b335ee 100644 --- a/tests/test_trade_exchange_stats_lib.py +++ b/tests/test_trade_exchange_stats_lib.py @@ -1,48 +1,48 @@ -import unittest - -from trade_exchange_stats_lib import ( - aggregate_bilateral_stats, - commission_usdt_from_fill, - filter_position_lifecycle_fills, - merge_commission_prefer_income, - quote_turnover_usdt_from_fill, -) - - -class TradeExchangeStatsTests(unittest.TestCase): - def test_turnover_from_cost(self): - t = {"cost": 1000.0, "price": 50, "amount": 20} - self.assertEqual(quote_turnover_usdt_from_fill(t), 1000.0) - - def test_commission_from_fee(self): - t = {"fee": {"cost": -0.42, "currency": "USDT"}} - self.assertEqual(commission_usdt_from_fill(t), 0.42) - - def test_bilateral_aggregate(self): - fills = [ - {"side": "buy", "cost": 500, "fee": {"cost": -0.2, "currency": "USDT"}, "timestamp": 1000}, - {"side": "sell", "cost": 520, "fee": {"cost": -0.21, "currency": "USDT"}, "timestamp": 2000}, - ] - stats = aggregate_bilateral_stats(fills) - self.assertIsNotNone(stats) - self.assertEqual(stats["exchange_turnover_usdt"], 1020.0) - self.assertEqual(stats["exchange_commission_usdt"], 0.41) - - def test_filter_long_lifecycle(self): - base = 1_700_000_000_000 - trades = [ - {"side": "buy", "timestamp": base, "cost": 100}, - {"side": "sell", "timestamp": base + 60_000, "cost": 110}, - {"side": "buy", "timestamp": base + 120_000, "cost": 999}, - ] - got = filter_position_lifecycle_fills( - trades, "long", base - 1000, base + 90_000, close_buffer_ms=0 - ) - self.assertEqual(len(got), 2) - - def test_prefer_income_commission(self): - self.assertEqual(merge_commission_prefer_income(0.3, 0.45), 0.45) - - -if __name__ == "__main__": - unittest.main() +import unittest + +from lib.trade.trade_exchange_stats_lib import ( + aggregate_bilateral_stats, + commission_usdt_from_fill, + filter_position_lifecycle_fills, + merge_commission_prefer_income, + quote_turnover_usdt_from_fill, +) + + +class TradeExchangeStatsTests(unittest.TestCase): + def test_turnover_from_cost(self): + t = {"cost": 1000.0, "price": 50, "amount": 20} + self.assertEqual(quote_turnover_usdt_from_fill(t), 1000.0) + + def test_commission_from_fee(self): + t = {"fee": {"cost": -0.42, "currency": "USDT"}} + self.assertEqual(commission_usdt_from_fill(t), 0.42) + + def test_bilateral_aggregate(self): + fills = [ + {"side": "buy", "cost": 500, "fee": {"cost": -0.2, "currency": "USDT"}, "timestamp": 1000}, + {"side": "sell", "cost": 520, "fee": {"cost": -0.21, "currency": "USDT"}, "timestamp": 2000}, + ] + stats = aggregate_bilateral_stats(fills) + self.assertIsNotNone(stats) + self.assertEqual(stats["exchange_turnover_usdt"], 1020.0) + self.assertEqual(stats["exchange_commission_usdt"], 0.41) + + def test_filter_long_lifecycle(self): + base = 1_700_000_000_000 + trades = [ + {"side": "buy", "timestamp": base, "cost": 100}, + {"side": "sell", "timestamp": base + 60_000, "cost": 110}, + {"side": "buy", "timestamp": base + 120_000, "cost": 999}, + ] + got = filter_position_lifecycle_fills( + trades, "long", base - 1000, base + 90_000, close_buffer_ms=0 + ) + self.assertEqual(len(got), 2) + + def test_prefer_income_commission(self): + self.assertEqual(merge_commission_prefer_income(0.3, 0.45), 0.45) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_trade_result_lib.py b/tests/test_trade_result_lib.py index 69bbfd0..a285c2e 100644 --- a/tests/test_trade_result_lib.py +++ b/tests/test_trade_result_lib.py @@ -1,30 +1,30 @@ -from trade_result_lib import normalize_result_with_pnl, normalize_display_result, is_winning_pnl - - -def test_stop_loss_with_profit_becomes_trailing_tp(): - assert normalize_result_with_pnl("止损", 4.33) == "移动止盈" - - -def test_manual_close_unchanged_even_with_profit(): - assert normalize_result_with_pnl("手动平仓", 10) == "手动平仓" - - -def test_stop_loss_with_loss_unchanged(): - assert normalize_result_with_pnl("止损", -2.5) == "止损" - - -def test_take_profit_unchanged(): - assert normalize_result_with_pnl("止盈", 5) == "止盈" - - -def test_external_close_becomes_manual_close(): - assert normalize_display_result("外部平仓") == "手动平仓" - assert normalize_result_with_pnl("外部平仓", 2.5) == "手动平仓" - assert normalize_result_with_pnl("外部平仓(自动同步)", -1) == "手动平仓" - - -def test_winning_pnl_positive_only(): - assert is_winning_pnl(2.96) is True - assert is_winning_pnl(0) is False - assert is_winning_pnl(-1.05) is False - assert is_winning_pnl(None) is False +from lib.trade.trade_result_lib import normalize_result_with_pnl, normalize_display_result, is_winning_pnl + + +def test_stop_loss_with_profit_becomes_trailing_tp(): + assert normalize_result_with_pnl("止损", 4.33) == "移动止盈" + + +def test_manual_close_unchanged_even_with_profit(): + assert normalize_result_with_pnl("手动平仓", 10) == "手动平仓" + + +def test_stop_loss_with_loss_unchanged(): + assert normalize_result_with_pnl("止损", -2.5) == "止损" + + +def test_take_profit_unchanged(): + assert normalize_result_with_pnl("止盈", 5) == "止盈" + + +def test_external_close_becomes_manual_close(): + assert normalize_display_result("外部平仓") == "手动平仓" + assert normalize_result_with_pnl("外部平仓", 2.5) == "手动平仓" + assert normalize_result_with_pnl("外部平仓(自动同步)", -1) == "手动平仓" + + +def test_winning_pnl_positive_only(): + assert is_winning_pnl(2.96) is True + assert is_winning_pnl(0) is False + assert is_winning_pnl(-1.05) is False + assert is_winning_pnl(None) is False diff --git a/tests/test_trade_stats_calendar_lib.py b/tests/test_trade_stats_calendar_lib.py index 10a191c..ab33704 100644 --- a/tests/test_trade_stats_calendar_lib.py +++ b/tests/test_trade_stats_calendar_lib.py @@ -1,90 +1,90 @@ -import unittest -from types import SimpleNamespace - -from datetime import datetime - -from trade_stats_calendar_lib import ( - build_initial_stats_calendar, - build_stats_calendar_bootstrap, - build_trade_stats_calendar, -) - - -def _row(**kwargs): - base = { - "monitor_type": "", - "key_signal_type": "", - "exchange_turnover_usdt": None, - "exchange_commission_usdt": None, - } - base.update(kwargs) - return SimpleNamespace(**base) - - -def _matches_all(row, segment_key): - return segment_key == "all" - - -def _matches_manual(row, segment_key): - if segment_key == "all": - return True - if segment_key == "manual": - return (row.monitor_type or "").strip() == "手动" and not (row.key_signal_type or "").strip() - return False - - -class TradeStatsCalendarLibTests(unittest.TestCase): - def test_groups_by_trading_day_and_segment(self): - pnls = [ - (10.0, None, "2026-06-18", _row(monitor_type="手动")), - (-3.0, None, "2026-06-18", _row(monitor_type="手动")), - (5.0, None, "2026-06-19", _row(monitor_type="自动", key_signal_type="箱体突破")), - ] - payload = build_trade_stats_calendar( - pnls, - 2026, - 6, - "manual", - _matches_manual, - reset_hour=8, - ) - self.assertEqual(payload["month"], 6) - self.assertEqual(payload["month_open_count"], 2) - days = payload["days"] - self.assertIn("2026-06-18", days) - self.assertNotIn("2026-06-19", days) - self.assertEqual(days["2026-06-18"]["open_count"], 2) - self.assertAlmostEqual(days["2026-06-18"]["pnl_total"], 7.0) - - def test_invalid_month_raises(self): - with self.assertRaises(ValueError): - build_trade_stats_calendar([], 2026, 13, "all", _matches_all) - - def test_initial_calendar_uses_current_month(self): - pnls = [(2.5, None, "2026-06-20", _row())] - payload = build_initial_stats_calendar( - pnls, - datetime(2026, 6, 26, 12, 0), - _matches_all, - reset_hour=8, - ) - self.assertEqual(payload["year"], 2026) - self.assertEqual(payload["month"], 6) - self.assertEqual(payload["month_open_count"], 1) - self.assertIn("2026-06-20", payload["days"]) - - def test_bootstrap_json_roundtrip(self): - pnls = [(2.5, None, "2026-06-20", _row())] - payload, raw = build_stats_calendar_bootstrap( - pnls, - datetime(2026, 6, 26, 12, 0), - _matches_all, - reset_hour=8, - ) - self.assertIsNotNone(payload) - self.assertIsNotNone(raw) - self.assertIn('"month_open_count":1', raw.replace(" ", "")) - - -if __name__ == "__main__": - unittest.main() +import unittest +from types import SimpleNamespace + +from datetime import datetime + +from lib.trade.trade_stats_calendar_lib import ( + build_initial_stats_calendar, + build_stats_calendar_bootstrap, + build_trade_stats_calendar, +) + + +def _row(**kwargs): + base = { + "monitor_type": "", + "key_signal_type": "", + "exchange_turnover_usdt": None, + "exchange_commission_usdt": None, + } + base.update(kwargs) + return SimpleNamespace(**base) + + +def _matches_all(row, segment_key): + return segment_key == "all" + + +def _matches_manual(row, segment_key): + if segment_key == "all": + return True + if segment_key == "manual": + return (row.monitor_type or "").strip() == "手动" and not (row.key_signal_type or "").strip() + return False + + +class TradeStatsCalendarLibTests(unittest.TestCase): + def test_groups_by_trading_day_and_segment(self): + pnls = [ + (10.0, None, "2026-06-18", _row(monitor_type="手动")), + (-3.0, None, "2026-06-18", _row(monitor_type="手动")), + (5.0, None, "2026-06-19", _row(monitor_type="自动", key_signal_type="箱体突破")), + ] + payload = build_trade_stats_calendar( + pnls, + 2026, + 6, + "manual", + _matches_manual, + reset_hour=8, + ) + self.assertEqual(payload["month"], 6) + self.assertEqual(payload["month_open_count"], 2) + days = payload["days"] + self.assertIn("2026-06-18", days) + self.assertNotIn("2026-06-19", days) + self.assertEqual(days["2026-06-18"]["open_count"], 2) + self.assertAlmostEqual(days["2026-06-18"]["pnl_total"], 7.0) + + def test_invalid_month_raises(self): + with self.assertRaises(ValueError): + build_trade_stats_calendar([], 2026, 13, "all", _matches_all) + + def test_initial_calendar_uses_current_month(self): + pnls = [(2.5, None, "2026-06-20", _row())] + payload = build_initial_stats_calendar( + pnls, + datetime(2026, 6, 26, 12, 0), + _matches_all, + reset_hour=8, + ) + self.assertEqual(payload["year"], 2026) + self.assertEqual(payload["month"], 6) + self.assertEqual(payload["month_open_count"], 1) + self.assertIn("2026-06-20", payload["days"]) + + def test_bootstrap_json_roundtrip(self): + pnls = [(2.5, None, "2026-06-20", _row())] + payload, raw = build_stats_calendar_bootstrap( + pnls, + datetime(2026, 6, 26, 12, 0), + _matches_all, + reset_hour=8, + ) + self.assertIsNotNone(payload) + self.assertIsNotNone(raw) + self.assertIn('"month_open_count":1', raw.replace(" ", "")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_trend_dca_enrich_fills.py b/tests/test_trend_dca_enrich_fills.py index 68f90d5..42a96bb 100644 --- a/tests/test_trend_dca_enrich_fills.py +++ b/tests/test_trend_dca_enrich_fills.py @@ -1,101 +1,101 @@ -"""趋势回调运行中计划:实际成交价重算补仓表与金额盈亏比。""" -from __future__ import annotations - -import json -import sys -import unittest -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(ROOT)) - -from strategy_snapshot_lib import attach_trend_dca_levels # noqa: E402 -from strategy_trend_lib import ( # noqa: E402 - calc_trend_plan_money_metrics, - trend_leg_display_price, -) - - -class TestTrendDcaEnrichFills(unittest.TestCase): - def _base_plan(self, **overrides): - plan = { - "direction": "long", - "stop_loss": 0.329, - "take_profit": 0.476, - "first_order_amount": 115, - "snapshot_available_usdt": 97.98, - "risk_percent": 5, - "contract_size": 1.0, - "grid_prices_json": json.dumps([0.3465, 0.343, 0.3395, 0.336, 0.3325]), - "leg_amounts_json": json.dumps([23, 23, 23, 23, 23]), - "dca_legs": 5, - "first_order_done": 1, - "legs_done": 0, - "avg_entry_price": 0.3537, - "order_amount_open": 115, - "target_order_amount": 230, - "leg_fill_prices_json": json.dumps([0.3537]), - } - plan.update(overrides) - return plan - - def test_header_money_rr_not_price_rr(self): - plan = self._base_plan() - metrics = calc_trend_plan_money_metrics(plan) - self.assertAlmostEqual(metrics["risk_amount_u"], 4.899, places=2) - self.assertIsNotNone(metrics["money_rr"]) - self.assertLess(metrics["money_rr"], 4.0) - - def test_done_dca_uses_actual_fill_price(self): - plan = self._base_plan( - legs_done=1, - avg_entry_price=0.3512, - order_amount_open=138, - leg_fill_prices_json=json.dumps([0.3537, 0.3458]), - ) - enriched = attach_trend_dca_levels(plan) - levels = enriched["dca_levels"] - self.assertEqual(len(levels), 6) - dca1 = levels[1] - self.assertEqual(dca1["status"], "done") - self.assertAlmostEqual(dca1["price"], 0.3458, places=4) - self.assertIsNotNone(dca1["avg_entry"]) - self.assertIsNotNone(dca1["rr"]) - dca2 = levels[2] - self.assertEqual(dca2["status"], "pending") - self.assertAlmostEqual(dca2["price"], 0.343, places=4) - - def test_missing_dca_fills_use_grid_trigger_not_inferred_price(self): - """缺补仓成交价时:触发价用计划网格,末档均价对齐头部,禁止反推离谱成交价。""" - plan = self._base_plan( - legs_done=2, - avg_entry_price=0.3507, - order_amount_open=161, - leg_fill_prices_json=json.dumps([0.3436]), - grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]), - ) - enriched = attach_trend_dca_levels(plan) - levels = enriched["dca_levels"] - dca1 = levels[1] - dca2 = levels[2] - self.assertEqual(dca1["status"], "done") - self.assertAlmostEqual(dca1["price"], 0.343, places=4) - self.assertEqual(dca2["status"], "done") - self.assertAlmostEqual(dca2["price"], 0.343, places=4) - self.assertAlmostEqual(dca2["avg_entry"], 0.3507, places=4) - self.assertLess(dca2["price"], 0.36) - - def test_display_price_never_infers_from_target_avg(self): - """四所共用:缺记录时只用网格,不因均价反推离谱触发价。""" - plan = self._base_plan( - legs_done=2, - avg_entry_price=0.3507, - leg_fill_prices_json=json.dumps([0.3436]), - grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]), - ) - self.assertAlmostEqual(trend_leg_display_price(plan, 2), 0.343, places=4) - self.assertLess(trend_leg_display_price(plan, 2), 0.36) - - -if __name__ == "__main__": - unittest.main() +"""趋势回调运行中计划:实际成交价重算补仓表与金额盈亏比。""" +from __future__ import annotations + +import json +import sys +import unittest +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + +from lib.strategy.strategy_snapshot_lib import attach_trend_dca_levels # noqa: E402 +from lib.strategy.strategy_trend_lib import ( # noqa: E402 + calc_trend_plan_money_metrics, + trend_leg_display_price, +) + + +class TestTrendDcaEnrichFills(unittest.TestCase): + def _base_plan(self, **overrides): + plan = { + "direction": "long", + "stop_loss": 0.329, + "take_profit": 0.476, + "first_order_amount": 115, + "snapshot_available_usdt": 97.98, + "risk_percent": 5, + "contract_size": 1.0, + "grid_prices_json": json.dumps([0.3465, 0.343, 0.3395, 0.336, 0.3325]), + "leg_amounts_json": json.dumps([23, 23, 23, 23, 23]), + "dca_legs": 5, + "first_order_done": 1, + "legs_done": 0, + "avg_entry_price": 0.3537, + "order_amount_open": 115, + "target_order_amount": 230, + "leg_fill_prices_json": json.dumps([0.3537]), + } + plan.update(overrides) + return plan + + def test_header_money_rr_not_price_rr(self): + plan = self._base_plan() + metrics = calc_trend_plan_money_metrics(plan) + self.assertAlmostEqual(metrics["risk_amount_u"], 4.899, places=2) + self.assertIsNotNone(metrics["money_rr"]) + self.assertLess(metrics["money_rr"], 4.0) + + def test_done_dca_uses_actual_fill_price(self): + plan = self._base_plan( + legs_done=1, + avg_entry_price=0.3512, + order_amount_open=138, + leg_fill_prices_json=json.dumps([0.3537, 0.3458]), + ) + enriched = attach_trend_dca_levels(plan) + levels = enriched["dca_levels"] + self.assertEqual(len(levels), 6) + dca1 = levels[1] + self.assertEqual(dca1["status"], "done") + self.assertAlmostEqual(dca1["price"], 0.3458, places=4) + self.assertIsNotNone(dca1["avg_entry"]) + self.assertIsNotNone(dca1["rr"]) + dca2 = levels[2] + self.assertEqual(dca2["status"], "pending") + self.assertAlmostEqual(dca2["price"], 0.343, places=4) + + def test_missing_dca_fills_use_grid_trigger_not_inferred_price(self): + """缺补仓成交价时:触发价用计划网格,末档均价对齐头部,禁止反推离谱成交价。""" + plan = self._base_plan( + legs_done=2, + avg_entry_price=0.3507, + order_amount_open=161, + leg_fill_prices_json=json.dumps([0.3436]), + grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]), + ) + enriched = attach_trend_dca_levels(plan) + levels = enriched["dca_levels"] + dca1 = levels[1] + dca2 = levels[2] + self.assertEqual(dca1["status"], "done") + self.assertAlmostEqual(dca1["price"], 0.343, places=4) + self.assertEqual(dca2["status"], "done") + self.assertAlmostEqual(dca2["price"], 0.343, places=4) + self.assertAlmostEqual(dca2["avg_entry"], 0.3507, places=4) + self.assertLess(dca2["price"], 0.36) + + def test_display_price_never_infers_from_target_avg(self): + """四所共用:缺记录时只用网格,不因均价反推离谱触发价。""" + plan = self._base_plan( + legs_done=2, + avg_entry_price=0.3507, + leg_fill_prices_json=json.dumps([0.3436]), + grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]), + ) + self.assertAlmostEqual(trend_leg_display_price(plan, 2), 0.343, places=4) + self.assertLess(trend_leg_display_price(plan, 2), 0.36) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_trend_dca_pnl.py b/tests/test_trend_dca_pnl.py index 54ee8b9..5832661 100644 --- a/tests/test_trend_dca_pnl.py +++ b/tests/test_trend_dca_pnl.py @@ -1,43 +1,43 @@ -"""趋势回调:补仓触达与有效保证金估算。""" -from strategy_trend_lib import trend_dca_level_reached, trend_effective_margin_capital - - -def test_trend_dca_short_monotonic_up_fills_missed_legs(): - """做空价升:旧逻辑需 last