refactor: 将共用代码迁入 lib/ 模块化目录

统一 strategy、key_monitor、trade、hub 等共用库到 lib/ 子包,并补充 lib-structure 文档,便于四所与中控维护。

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-07-02 16:23:09 +08:00
parent 4742a0bb9d
commit 5797d49d8a
190 changed files with 27946 additions and 27499 deletions
File diff suppressed because it is too large Load Diff
+63 -63
View File
@@ -1,63 +1,63 @@
"""AI 复盘 journal 文本格式化(四所共用)。"""
from __future__ import annotations
import sqlite3
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from ai_review_lib import journal_row_lines_for_ai # noqa: E402
class TestAiReviewLib(unittest.TestCase):
def test_journal_row_includes_expect_and_actual_rr(self):
text = journal_row_lines_for_ai(
1,
{
"coin": "HYPE",
"tf": "5m",
"pnl": "10.73",
"real_rr": "2.1354",
"expect_rr": "-",
"entry_reason": "趋势回调",
"exit_reason": "移动止盈",
"hold_duration": "1天 3小时",
"mood_issues": "",
"post_breakeven_stare": "",
"new_trade_while_occupied": "",
"note": "测试备注",
},
)
self.assertIn("实际RR:2.1354", text)
self.assertIn("预期RR:-", text)
self.assertIn("开仓逻辑:趋势回调", text)
self.assertIn("备注:测试备注", text)
self.assertNotIn("开仓类型", text)
def test_journal_row_accepts_sqlite_row(self):
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"""CREATE TABLE journal_entries (
coin TEXT, tf TEXT, pnl TEXT, real_rr TEXT, expect_rr TEXT,
entry_reason TEXT, exit_reason TEXT, hold_duration TEXT,
mood_issues TEXT, mood_score INTEGER, note TEXT
)"""
)
conn.execute(
"""INSERT INTO journal_entries VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
("BTC", "15m", "5", "1.2", "2.0", "突破", "止盈", "2小时", "", None, ""),
)
row = conn.execute("SELECT * FROM journal_entries").fetchone()
conn.close()
text = journal_row_lines_for_ai(1, row)
self.assertIn("BTC 15m", text)
self.assertIn("实际RR:1.2", text)
self.assertIn("开仓逻辑:突破", text)
if __name__ == "__main__":
unittest.main()
"""AI 复盘 journal 文本格式化(四所共用)。"""
from __future__ import annotations
import sqlite3
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from lib.ai.ai_review_lib import journal_row_lines_for_ai # noqa: E402
class TestAiReviewLib(unittest.TestCase):
def test_journal_row_includes_expect_and_actual_rr(self):
text = journal_row_lines_for_ai(
1,
{
"coin": "HYPE",
"tf": "5m",
"pnl": "10.73",
"real_rr": "2.1354",
"expect_rr": "-",
"entry_reason": "趋势回调",
"exit_reason": "移动止盈",
"hold_duration": "1天 3小时",
"mood_issues": "",
"post_breakeven_stare": "",
"new_trade_while_occupied": "",
"note": "测试备注",
},
)
self.assertIn("实际RR:2.1354", text)
self.assertIn("预期RR:-", text)
self.assertIn("开仓逻辑:趋势回调", text)
self.assertIn("备注:测试备注", text)
self.assertNotIn("开仓类型", text)
def test_journal_row_accepts_sqlite_row(self):
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"""CREATE TABLE journal_entries (
coin TEXT, tf TEXT, pnl TEXT, real_rr TEXT, expect_rr TEXT,
entry_reason TEXT, exit_reason TEXT, hold_duration TEXT,
mood_issues TEXT, mood_score INTEGER, note TEXT
)"""
)
conn.execute(
"""INSERT INTO journal_entries VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
("BTC", "15m", "5", "1.2", "2.0", "突破", "止盈", "2小时", "", None, ""),
)
row = conn.execute("SELECT * FROM journal_entries").fetchone()
conn.close()
text = journal_row_lines_for_ai(1, row)
self.assertIn("BTC 15m", text)
self.assertIn("实际RR:1.2", text)
self.assertIn("开仓逻辑:突破", text)
if __name__ == "__main__":
unittest.main()
+60 -60
View File
@@ -1,60 +1,60 @@
import sqlite3
import tempfile
import unittest
from datetime import datetime
from pathlib import Path
from zoneinfo import ZoneInfo
from hub_symbol_archive_lib import init_db, list_archive_calendar, upsert_trades_cache, upsert_trade_overlay
def _bj_ms(y, m, d, hh, mm):
dt = datetime(y, m, d, hh, mm, 0, tzinfo=ZoneInfo("Asia/Shanghai"))
return int(dt.timestamp() * 1000)
class ArchiveCalendarTests(unittest.TestCase):
def test_calendar_groups_by_trading_day_and_sick(self):
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "arch.db"
init_db(db)
upsert_trades_cache(
"binance",
[
{
"id": 1,
"symbol": "BTC/USDT",
"direction": "long",
"result": "止盈",
"pnl_amount": 10.0,
"opened_at": "2026-06-18 09:00:00",
"closed_at": "2026-06-18 10:00:00",
"closed_at_ms": _bj_ms(2026, 6, 18, 10, 0),
"exchange_turnover_usdt": 2000.0,
"exchange_commission_usdt": 0.8,
},
{
"id": 2,
"symbol": "ETH/USDT",
"direction": "short",
"result": "止损",
"pnl_amount": -5.0,
"opened_at": "2026-06-18 14:00:00",
"closed_at": "2026-06-18 15:00:00",
"closed_at_ms": _bj_ms(2026, 6, 18, 15, 0),
},
],
db_path=db,
)
upsert_trade_overlay("binance", 2, behavior_tag="sick", db_path=db)
payload = list_archive_calendar(2026, 6, db_path=db)
self.assertEqual(payload["month"], 6)
days = payload["days"]
self.assertTrue(days)
sick_days = [d for d in days.values() if d.get("has_sick")]
self.assertTrue(sick_days)
self.assertGreaterEqual(payload["month_open_count"], 2)
if __name__ == "__main__":
unittest.main()
import sqlite3
import tempfile
import unittest
from datetime import datetime
from pathlib import Path
from zoneinfo import ZoneInfo
from lib.hub.hub_symbol_archive_lib import init_db, list_archive_calendar, upsert_trades_cache, upsert_trade_overlay
def _bj_ms(y, m, d, hh, mm):
dt = datetime(y, m, d, hh, mm, 0, tzinfo=ZoneInfo("Asia/Shanghai"))
return int(dt.timestamp() * 1000)
class ArchiveCalendarTests(unittest.TestCase):
def test_calendar_groups_by_trading_day_and_sick(self):
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "arch.db"
init_db(db)
upsert_trades_cache(
"binance",
[
{
"id": 1,
"symbol": "BTC/USDT",
"direction": "long",
"result": "止盈",
"pnl_amount": 10.0,
"opened_at": "2026-06-18 09:00:00",
"closed_at": "2026-06-18 10:00:00",
"closed_at_ms": _bj_ms(2026, 6, 18, 10, 0),
"exchange_turnover_usdt": 2000.0,
"exchange_commission_usdt": 0.8,
},
{
"id": 2,
"symbol": "ETH/USDT",
"direction": "short",
"result": "止损",
"pnl_amount": -5.0,
"opened_at": "2026-06-18 14:00:00",
"closed_at": "2026-06-18 15:00:00",
"closed_at_ms": _bj_ms(2026, 6, 18, 15, 0),
},
],
db_path=db,
)
upsert_trade_overlay("binance", 2, behavior_tag="sick", db_path=db)
payload = list_archive_calendar(2026, 6, db_path=db)
self.assertEqual(payload["month"], 6)
days = payload["days"]
self.assertTrue(days)
sick_days = [d for d in days.values() if d.get("has_sick")]
self.assertTrue(sick_days)
self.assertGreaterEqual(payload["month_open_count"], 2)
if __name__ == "__main__":
unittest.main()
+90 -90
View File
@@ -1,90 +1,90 @@
import unittest
from daily_open_limit_lib import (
build_daily_open_alert_prompt,
can_trade_new_open,
check_daily_open_hard_limit,
count_opens_for_trading_day,
daily_open_hard_limit_blocks,
format_daily_open_counter_line,
hard_limit_block_reason,
load_daily_open_limits_from_env,
parse_daily_open_hard_limit,
should_send_daily_open_alert,
)
class _FakeConn:
def __init__(self, count: int):
self._count = count
def execute(self, _sql, _params):
return self
def fetchone(self):
return (self._count,)
class DailyOpenLimitLibTests(unittest.TestCase):
def test_parse_hard_limit_zero_disables(self):
self.assertEqual(parse_daily_open_hard_limit("0"), 0)
self.assertEqual(parse_daily_open_hard_limit(None, default=0), 0)
def test_load_from_env(self):
alert, hard = load_daily_open_limits_from_env(
{"DAILY_OPEN_ALERT_THRESHOLD": "3", "DAILY_OPEN_HARD_LIMIT": "8"}
)
self.assertEqual(alert, 3)
self.assertEqual(hard, 8)
def test_hard_limit_blocks(self):
self.assertFalse(daily_open_hard_limit_blocks(4, 0))
self.assertFalse(daily_open_hard_limit_blocks(4, 5))
self.assertTrue(daily_open_hard_limit_blocks(5, 5))
def test_check_daily_open_hard_limit(self):
conn = _FakeConn(5)
ok, reason, n = check_daily_open_hard_limit(conn, "2026-06-07", 5, 8)
self.assertFalse(ok)
self.assertEqual(n, 5)
self.assertIn("已达上限", reason)
self.assertIn("8:00", reason)
def test_count_opens(self):
self.assertEqual(count_opens_for_trading_day(_FakeConn(3), "2026-06-07"), 3)
def test_can_trade_new_open(self):
self.assertTrue(
can_trade_new_open(
time_allows=True,
active_count=0,
max_active_positions=1,
opens_today=2,
hard_limit=5,
)
)
self.assertFalse(
can_trade_new_open(
time_allows=True,
active_count=0,
max_active_positions=1,
opens_today=5,
hard_limit=5,
)
)
def test_alert_crossing(self):
self.assertTrue(should_send_daily_open_alert(4, 5, 5))
self.assertFalse(should_send_daily_open_alert(5, 6, 5))
def test_prompt_includes_hard_limit(self):
txt = build_daily_open_alert_prompt("2026-06-07", 5, 5, hard_limit=8)
self.assertIn("硬上限 8", txt)
def test_counter_line(self):
line = format_daily_open_counter_line(3, 5, 8)
self.assertIn("3 / 硬上限 8", line)
if __name__ == "__main__":
unittest.main()
import unittest
from lib.trade.daily_open_limit_lib import (
build_daily_open_alert_prompt,
can_trade_new_open,
check_daily_open_hard_limit,
count_opens_for_trading_day,
daily_open_hard_limit_blocks,
format_daily_open_counter_line,
hard_limit_block_reason,
load_daily_open_limits_from_env,
parse_daily_open_hard_limit,
should_send_daily_open_alert,
)
class _FakeConn:
def __init__(self, count: int):
self._count = count
def execute(self, _sql, _params):
return self
def fetchone(self):
return (self._count,)
class DailyOpenLimitLibTests(unittest.TestCase):
def test_parse_hard_limit_zero_disables(self):
self.assertEqual(parse_daily_open_hard_limit("0"), 0)
self.assertEqual(parse_daily_open_hard_limit(None, default=0), 0)
def test_load_from_env(self):
alert, hard = load_daily_open_limits_from_env(
{"DAILY_OPEN_ALERT_THRESHOLD": "3", "DAILY_OPEN_HARD_LIMIT": "8"}
)
self.assertEqual(alert, 3)
self.assertEqual(hard, 8)
def test_hard_limit_blocks(self):
self.assertFalse(daily_open_hard_limit_blocks(4, 0))
self.assertFalse(daily_open_hard_limit_blocks(4, 5))
self.assertTrue(daily_open_hard_limit_blocks(5, 5))
def test_check_daily_open_hard_limit(self):
conn = _FakeConn(5)
ok, reason, n = check_daily_open_hard_limit(conn, "2026-06-07", 5, 8)
self.assertFalse(ok)
self.assertEqual(n, 5)
self.assertIn("已达上限", reason)
self.assertIn("8:00", reason)
def test_count_opens(self):
self.assertEqual(count_opens_for_trading_day(_FakeConn(3), "2026-06-07"), 3)
def test_can_trade_new_open(self):
self.assertTrue(
can_trade_new_open(
time_allows=True,
active_count=0,
max_active_positions=1,
opens_today=2,
hard_limit=5,
)
)
self.assertFalse(
can_trade_new_open(
time_allows=True,
active_count=0,
max_active_positions=1,
opens_today=5,
hard_limit=5,
)
)
def test_alert_crossing(self):
self.assertTrue(should_send_daily_open_alert(4, 5, 5))
self.assertFalse(should_send_daily_open_alert(5, 6, 5))
def test_prompt_includes_hard_limit(self):
txt = build_daily_open_alert_prompt("2026-06-07", 5, 5, hard_limit=8)
self.assertIn("硬上限 8", txt)
def test_counter_line(self):
line = format_daily_open_counter_line(3, 5, 8)
self.assertIn("3 / 硬上限 8", line)
if __name__ == "__main__":
unittest.main()
+76 -76
View File
@@ -1,76 +1,76 @@
import unittest
from datetime import datetime, timedelta
from false_breakout_key_monitor_lib import (
FALSE_BREAKOUT_MONITOR_TYPE,
calc_false_breakout_plan,
false_breakout_gate_preview,
is_false_breakout_expired,
key_price_from_row,
normalize_false_breakout_symbol,
storage_bounds_from_key_price,
)
class FalseBreakoutKeyMonitorLibTests(unittest.TestCase):
def test_normalize_symbol(self):
self.assertEqual(normalize_false_breakout_symbol("btc"), "BTC/USDT")
self.assertEqual(normalize_false_breakout_symbol("ETH/USDT"), "ETH/USDT")
self.assertIsNone(normalize_false_breakout_symbol("SOL"))
def test_short_plan(self):
plan = calc_false_breakout_plan("short", 100000)
self.assertIsNotNone(plan)
entry, sl, tp = plan
self.assertAlmostEqual(entry, 100100.0)
self.assertAlmostEqual(sl, 100600.5)
self.assertAlmostEqual(tp, 99349.25)
def test_long_plan(self):
plan = calc_false_breakout_plan("long", 100000)
self.assertIsNotNone(plan)
entry, sl, tp = plan
self.assertAlmostEqual(entry, 99900.0)
self.assertAlmostEqual(sl, 99400.5)
self.assertAlmostEqual(tp, 100649.25)
def test_storage_bounds(self):
up, low = storage_bounds_from_key_price("short", 100000)
self.assertGreater(up, low)
self.assertAlmostEqual(up, 100000.0)
self.assertAlmostEqual(low, 99990.0)
up, low = storage_bounds_from_key_price("long", 100000)
self.assertGreater(up, low)
self.assertAlmostEqual(low, 100000.0)
self.assertAlmostEqual(up, 100010.0)
def test_key_price_from_row(self):
self.assertEqual(key_price_from_row("short", 100100, 100000), 100100)
self.assertEqual(key_price_from_row("long", 100100, 100000), 100000)
def test_expiry(self):
now = datetime(2026, 6, 9, 12, 0, 0)
created = "2026-06-08 12:00:00"
self.assertTrue(is_false_breakout_expired(created, now))
self.assertFalse(is_false_breakout_expired(created, now - timedelta(hours=1)))
def test_monitor_type_constant(self):
self.assertEqual(FALSE_BREAKOUT_MONITOR_TYPE, "假突破")
def test_gate_preview_not_box_gate(self):
now = datetime(2026, 6, 7, 12, 0, 0)
prev = false_breakout_gate_preview(
entry_display="1635.0",
limit_order_id="oid-1",
created_at="2026-06-07 10:00:00",
now=now,
)
self.assertIn("假突破", prev["summary"])
self.assertIn("等待成交", prev["summary"])
self.assertNotIn("量:", prev["summary"])
self.assertIn("限价单:oid-1", prev["metrics"])
self.assertTrue(prev["gate_ok"])
if __name__ == "__main__":
unittest.main()
import unittest
from datetime import datetime, timedelta
from lib.key_monitor.false_breakout_key_monitor_lib import (
FALSE_BREAKOUT_MONITOR_TYPE,
calc_false_breakout_plan,
false_breakout_gate_preview,
is_false_breakout_expired,
key_price_from_row,
normalize_false_breakout_symbol,
storage_bounds_from_key_price,
)
class FalseBreakoutKeyMonitorLibTests(unittest.TestCase):
def test_normalize_symbol(self):
self.assertEqual(normalize_false_breakout_symbol("btc"), "BTC/USDT")
self.assertEqual(normalize_false_breakout_symbol("ETH/USDT"), "ETH/USDT")
self.assertIsNone(normalize_false_breakout_symbol("SOL"))
def test_short_plan(self):
plan = calc_false_breakout_plan("short", 100000)
self.assertIsNotNone(plan)
entry, sl, tp = plan
self.assertAlmostEqual(entry, 100100.0)
self.assertAlmostEqual(sl, 100600.5)
self.assertAlmostEqual(tp, 99349.25)
def test_long_plan(self):
plan = calc_false_breakout_plan("long", 100000)
self.assertIsNotNone(plan)
entry, sl, tp = plan
self.assertAlmostEqual(entry, 99900.0)
self.assertAlmostEqual(sl, 99400.5)
self.assertAlmostEqual(tp, 100649.25)
def test_storage_bounds(self):
up, low = storage_bounds_from_key_price("short", 100000)
self.assertGreater(up, low)
self.assertAlmostEqual(up, 100000.0)
self.assertAlmostEqual(low, 99990.0)
up, low = storage_bounds_from_key_price("long", 100000)
self.assertGreater(up, low)
self.assertAlmostEqual(low, 100000.0)
self.assertAlmostEqual(up, 100010.0)
def test_key_price_from_row(self):
self.assertEqual(key_price_from_row("short", 100100, 100000), 100100)
self.assertEqual(key_price_from_row("long", 100100, 100000), 100000)
def test_expiry(self):
now = datetime(2026, 6, 9, 12, 0, 0)
created = "2026-06-08 12:00:00"
self.assertTrue(is_false_breakout_expired(created, now))
self.assertFalse(is_false_breakout_expired(created, now - timedelta(hours=1)))
def test_monitor_type_constant(self):
self.assertEqual(FALSE_BREAKOUT_MONITOR_TYPE, "假突破")
def test_gate_preview_not_box_gate(self):
now = datetime(2026, 6, 7, 12, 0, 0)
prev = false_breakout_gate_preview(
entry_display="1635.0",
limit_order_id="oid-1",
created_at="2026-06-07 10:00:00",
now=now,
)
self.assertIn("假突破", prev["summary"])
self.assertIn("等待成交", prev["summary"])
self.assertNotIn("量:", prev["summary"])
self.assertIn("限价单:oid-1", prev["metrics"])
self.assertTrue(prev["gate_ok"])
if __name__ == "__main__":
unittest.main()
+26 -26
View File
@@ -1,26 +1,26 @@
from gate_position_history_lib import pick_gate_position_close, unified_symbol_for_match
def test_unified_symbol_strips_settle_suffix():
assert unified_symbol_for_match("BTC/USDT:USDT") == "BTC/USDT"
def test_pick_gate_position_close_matches_symbol_side_and_time():
hist = [
{
"symbol_u": "SOL/USDT",
"side": "short",
"close_ms": 1_700_000_000_000,
"open_ms": 1_699_999_000_000,
"pnl": -1.25,
"sync_key": "SOL_USDT|1|short",
}
]
hit = pick_gate_position_close(
hist,
"SOL/USDT:USDT",
"short",
opened_at_ms=1_699_999_500_000,
)
assert hit is not None
assert hit["pnl"] == -1.25
from lib.exchange.gate_position_history_lib import pick_gate_position_close, unified_symbol_for_match
def test_unified_symbol_strips_settle_suffix():
assert unified_symbol_for_match("BTC/USDT:USDT") == "BTC/USDT"
def test_pick_gate_position_close_matches_symbol_side_and_time():
hist = [
{
"symbol_u": "SOL/USDT",
"side": "short",
"close_ms": 1_700_000_000_000,
"open_ms": 1_699_999_000_000,
"pnl": -1.25,
"sync_key": "SOL_USDT|1|short",
}
]
hit = pick_gate_position_close(
hist,
"SOL/USDT:USDT",
"short",
opened_at_ms=1_699_999_500_000,
)
assert hit is not None
assert hit["pnl"] == -1.25
+44 -44
View File
@@ -1,44 +1,44 @@
"""gate_transfer_lib 单元测试。"""
from __future__ import annotations
import sqlite3
import unittest
from gate_transfer_lib import count_auto_transfer_blockers
class GateTransferLibTest(unittest.TestCase):
def test_counts_order_monitors_first(self):
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE order_monitors (status TEXT)")
conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)")
conn.execute("INSERT INTO order_monitors VALUES ('active')")
conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 1)")
conn.commit()
n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 1)
self.assertEqual(n, 1)
conn.close()
def test_counts_trend_plan_when_no_order_monitors(self):
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE order_monitors (status TEXT)")
conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)")
conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 1)")
conn.commit()
n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 0)
self.assertEqual(n, 1)
conn.close()
def test_ignores_trend_plan_without_first_order(self):
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE order_monitors (status TEXT)")
conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)")
conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 0)")
conn.commit()
n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 0)
self.assertEqual(n, 0)
conn.close()
if __name__ == "__main__":
unittest.main()
"""gate_transfer_lib 单元测试。"""
from __future__ import annotations
import sqlite3
import unittest
from lib.exchange.gate_transfer_lib import count_auto_transfer_blockers
class GateTransferLibTest(unittest.TestCase):
def test_counts_order_monitors_first(self):
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE order_monitors (status TEXT)")
conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)")
conn.execute("INSERT INTO order_monitors VALUES ('active')")
conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 1)")
conn.commit()
n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 1)
self.assertEqual(n, 1)
conn.close()
def test_counts_trend_plan_when_no_order_monitors(self):
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE order_monitors (status TEXT)")
conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)")
conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 1)")
conn.commit()
n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 0)
self.assertEqual(n, 1)
conn.close()
def test_ignores_trend_plan_without_first_order(self):
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE order_monitors (status TEXT)")
conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)")
conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 0)")
conn.commit()
n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 0)
self.assertEqual(n, 0)
conn.close()
if __name__ == "__main__":
unittest.main()
+94 -94
View File
@@ -1,94 +1,94 @@
"""子代理持仓:四所标记价字段统一解析。"""
from __future__ import annotations
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "manual_trading_hub"))
from agent import _position_mark_price, _ticker_mark_price # noqa: E402
sys.path.insert(0, str(ROOT))
from hub_position_metrics import ( # noqa: E402
enrich_ccxt_position_metrics_out,
estimate_linear_swap_upnl_usdt,
parse_position_unrealized_pnl,
resolve_position_display_upnl,
)
class TestHubAgentMarkPrice(unittest.TestCase):
def test_binance_mark_price(self):
px = _position_mark_price({"markPrice": 65880.1, "info": {}})
self.assertAlmostEqual(px, 65880.1)
def test_okx_mark_px(self):
px = _position_mark_price({"info": {"markPx": "72.85"}})
self.assertAlmostEqual(px, 72.85)
def test_gate_info_mark(self):
px = _position_mark_price({"info": {"mark_price": "0.2241"}})
self.assertAlmostEqual(px, 0.2241)
def test_missing_returns_none(self):
self.assertIsNone(_position_mark_price({"info": {}}))
def test_infer_from_notional_and_contracts(self):
p = {"notional": 1000, "contracts": 10, "info": {}}
px = _position_mark_price(p)
self.assertAlmostEqual(px, 100.0)
def test_ticker_fallback(self):
class _Ex:
def fetch_ticker(self, sym):
return {"mark": 99.5, "info": {}}
self.assertAlmostEqual(_ticker_mark_price(_Ex(), "BTC/USDT:USDT"), 99.5)
def test_gate_unrealised_pnl_in_info(self):
pnl = parse_position_unrealized_pnl(
{"info": {"unrealised_pnl": "6.81"}, "unrealizedPnl": None}
)
self.assertAlmostEqual(pnl, 6.81)
def test_okx_upl_signed(self):
pnl = parse_position_unrealized_pnl(
{"info": {"upl": "-2.15"}, "unrealizedPnl": None}
)
self.assertAlmostEqual(pnl, -2.15)
def test_enrich_aligns_short_gate_metrics(self):
pos = {
"side": "short",
"contracts": 11,
"entryPrice": 73.187,
"markPrice": 66.038,
"info": {"unrealised_pnl": "7.86"},
}
out = {"unrealized_pnl": 7.86, "mark_price": 66.038}
enrich_ccxt_position_metrics_out(pos, out, contract_size=1.0, funds_decimals=2)
self.assertGreater(out["unrealized_pnl"], 70.0)
def test_estimate_short_hype_contract_size(self):
upnl = estimate_linear_swap_upnl_usdt(
"short", 73.187, 66.038, 11, 0.1
)
self.assertAlmostEqual(upnl, 7.86, places=1)
def test_resolve_prefers_computed_when_exchange_off(self):
shown = resolve_position_display_upnl(
"short", 73.187, 66.038, 11, 1.0, 7.86
)
self.assertAlmostEqual(shown, 78.64, places=1)
def test_resolve_keeps_exchange_when_aligned(self):
shown = resolve_position_display_upnl(
"short", 73.187, 66.038, 11, 0.1, 7.86
)
self.assertAlmostEqual(shown, 7.86, places=2)
if __name__ == "__main__":
unittest.main()
"""子代理持仓:四所标记价字段统一解析。"""
from __future__ import annotations
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "manual_trading_hub"))
from agent import _position_mark_price, _ticker_mark_price # noqa: E402
sys.path.insert(0, str(ROOT))
from lib.hub.hub_position_metrics import ( # noqa: E402
enrich_ccxt_position_metrics_out,
estimate_linear_swap_upnl_usdt,
parse_position_unrealized_pnl,
resolve_position_display_upnl,
)
class TestHubAgentMarkPrice(unittest.TestCase):
def test_binance_mark_price(self):
px = _position_mark_price({"markPrice": 65880.1, "info": {}})
self.assertAlmostEqual(px, 65880.1)
def test_okx_mark_px(self):
px = _position_mark_price({"info": {"markPx": "72.85"}})
self.assertAlmostEqual(px, 72.85)
def test_gate_info_mark(self):
px = _position_mark_price({"info": {"mark_price": "0.2241"}})
self.assertAlmostEqual(px, 0.2241)
def test_missing_returns_none(self):
self.assertIsNone(_position_mark_price({"info": {}}))
def test_infer_from_notional_and_contracts(self):
p = {"notional": 1000, "contracts": 10, "info": {}}
px = _position_mark_price(p)
self.assertAlmostEqual(px, 100.0)
def test_ticker_fallback(self):
class _Ex:
def fetch_ticker(self, sym):
return {"mark": 99.5, "info": {}}
self.assertAlmostEqual(_ticker_mark_price(_Ex(), "BTC/USDT:USDT"), 99.5)
def test_gate_unrealised_pnl_in_info(self):
pnl = parse_position_unrealized_pnl(
{"info": {"unrealised_pnl": "6.81"}, "unrealizedPnl": None}
)
self.assertAlmostEqual(pnl, 6.81)
def test_okx_upl_signed(self):
pnl = parse_position_unrealized_pnl(
{"info": {"upl": "-2.15"}, "unrealizedPnl": None}
)
self.assertAlmostEqual(pnl, -2.15)
def test_enrich_aligns_short_gate_metrics(self):
pos = {
"side": "short",
"contracts": 11,
"entryPrice": 73.187,
"markPrice": 66.038,
"info": {"unrealised_pnl": "7.86"},
}
out = {"unrealized_pnl": 7.86, "mark_price": 66.038}
enrich_ccxt_position_metrics_out(pos, out, contract_size=1.0, funds_decimals=2)
self.assertGreater(out["unrealized_pnl"], 70.0)
def test_estimate_short_hype_contract_size(self):
upnl = estimate_linear_swap_upnl_usdt(
"short", 73.187, 66.038, 11, 0.1
)
self.assertAlmostEqual(upnl, 7.86, places=1)
def test_resolve_prefers_computed_when_exchange_off(self):
shown = resolve_position_display_upnl(
"short", 73.187, 66.038, 11, 1.0, 7.86
)
self.assertAlmostEqual(shown, 78.64, places=1)
def test_resolve_keeps_exchange_when_aligned(self):
shown = resolve_position_display_upnl(
"short", 73.187, 66.038, 11, 0.1, 7.86
)
self.assertAlmostEqual(shown, 7.86, places=2)
if __name__ == "__main__":
unittest.main()
+162 -162
View File
@@ -1,162 +1,162 @@
"""hub_calculator_lib 测算逻辑。"""
import unittest
from unittest.mock import patch
from hub_calculator_lib import (
calc_initial_roll_qty,
calc_roll_calculator,
calc_trend_calculator,
solve_add_amount_for_total_risk,
)
MOCK_MARKET = {
"exchange_id": "0",
"exchange_key": "binance",
"exchange_name": "币安 · crypto_monitor_binance",
"exchange_label": "币安 · crypto_monitor_binance",
"base": "ETH",
"exchange_symbol": "ETH/USDT:USDT",
"display_symbol": "ETH/USDT",
"contract_size": 1.0,
"price_tick": 0.01,
"price_decimals": 2,
"amount_decimals": 3,
"min_amount": 0.001,
}
def _mock_resolve(_exchange="binance", _base="ETH"):
return MOCK_MARKET, lambda amount: round(float(amount), 3), None
class HubCalculatorLibTests(unittest.TestCase):
@patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve())
def test_trend_calculator_long_basic(self, _mock):
data, err = calc_trend_calculator(
direction="long",
capital_usdt=1000,
risk_percent=5,
leverage=5,
entry_price=100,
stop_loss=95,
add_upper=110,
take_profit=120,
dca_legs=3,
exchange_id="0",
base="ETH",
)
self.assertIsNone(err)
self.assertIsNotNone(data)
assert data is not None
self.assertEqual(data["risk_budget_u"], 50.0)
self.assertGreaterEqual(len(data["rows"]), 2)
self.assertEqual(data["rows"][0]["label"], "首仓")
self.assertEqual(data["market"]["display_symbol"], "ETH/USDT")
@patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve())
def test_trend_calculator_short_rejects_bad_bounds(self, _mock):
data, err = calc_trend_calculator(
direction="short",
capital_usdt=1000,
risk_percent=5,
leverage=5,
entry_price=100,
stop_loss=90,
add_upper=110,
take_profit=80,
dca_legs=3,
)
self.assertIsNone(data)
self.assertIsNotNone(err)
@patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve())
def test_roll_calculator_first_leg_auto(self, _mock):
data, err = calc_roll_calculator(
direction="long",
capital_usdt=1000,
risk_percent=5,
entry_price=100,
stop_loss=95,
take_profit=120,
add_legs=[],
legs_done=0,
)
self.assertIsNone(err)
self.assertIsNotNone(data)
assert data is not None
self.assertEqual(data["first_contracts"], 10.0)
self.assertEqual(len(data["rows"]), 1)
self.assertEqual(data["rows"][0]["loss_at_sl_u"], 50.0)
self.assertEqual(data["rows"][0]["profit_at_tp_u"], 200.0)
@patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve())
def test_roll_calculator_chain_two_legs(self, _mock):
data, err = calc_roll_calculator(
direction="long",
capital_usdt=1000,
risk_percent=5,
entry_price=100,
stop_loss=95,
take_profit=120,
add_legs=[
{"add_price": 105, "new_stop_loss": 98},
{"add_price": 108, "new_stop_loss": 101},
],
legs_done=0,
)
self.assertIsNone(err)
self.assertIsNotNone(data)
assert data is not None
self.assertEqual(len(data["rows"]), 3)
self.assertEqual(data["rows"][1]["label"], "滚仓1")
self.assertGreater(float(data["final_contracts"]), float(data["first_contracts"]))
@patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve())
def test_roll_calculator_rejects_too_many_legs(self, _mock):
data, err = calc_roll_calculator(
direction="long",
capital_usdt=1000,
risk_percent=5,
entry_price=100,
stop_loss=95,
take_profit=120,
add_legs=[
{"add_price": 105, "new_stop_loss": 98},
{"add_price": 108, "new_stop_loss": 101},
{"add_price": 110, "new_stop_loss": 103},
{"add_price": 112, "new_stop_loss": 105},
],
legs_done=0,
)
self.assertIsNone(data)
self.assertIsNotNone(err)
def test_initial_roll_qty(self):
qty, err = calc_initial_roll_qty("long", 100, 95, 50, 1.0)
self.assertIsNone(err)
self.assertEqual(qty, 10.0)
def test_initial_roll_qty_with_contract_size(self):
qty, err = calc_initial_roll_qty("long", 100, 95, 50, 0.1)
self.assertIsNone(err)
self.assertEqual(qty, 100.0)
def test_solve_add_with_contract_size(self):
q2, err = solve_add_amount_for_total_risk(
"long",
qty_existing=10.0,
entry_existing=100.0,
add_price=105.0,
new_stop=98.0,
risk_budget_usdt=50.0,
contract_size=1.0,
)
self.assertIsNone(err)
self.assertIsNotNone(q2)
assert q2 is not None
self.assertGreater(q2, 0)
if __name__ == "__main__":
unittest.main()
"""hub_calculator_lib 测算逻辑。"""
import unittest
from unittest.mock import patch
from lib.hub.hub_calculator_lib import (
calc_initial_roll_qty,
calc_roll_calculator,
calc_trend_calculator,
solve_add_amount_for_total_risk,
)
MOCK_MARKET = {
"exchange_id": "0",
"exchange_key": "binance",
"exchange_name": "币安 · crypto_monitor_binance",
"exchange_label": "币安 · crypto_monitor_binance",
"base": "ETH",
"exchange_symbol": "ETH/USDT:USDT",
"display_symbol": "ETH/USDT",
"contract_size": 1.0,
"price_tick": 0.01,
"price_decimals": 2,
"amount_decimals": 3,
"min_amount": 0.001,
}
def _mock_resolve(_exchange="binance", _base="ETH"):
return MOCK_MARKET, lambda amount: round(float(amount), 3), None
class HubCalculatorLibTests(unittest.TestCase):
@patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve())
def test_trend_calculator_long_basic(self, _mock):
data, err = calc_trend_calculator(
direction="long",
capital_usdt=1000,
risk_percent=5,
leverage=5,
entry_price=100,
stop_loss=95,
add_upper=110,
take_profit=120,
dca_legs=3,
exchange_id="0",
base="ETH",
)
self.assertIsNone(err)
self.assertIsNotNone(data)
assert data is not None
self.assertEqual(data["risk_budget_u"], 50.0)
self.assertGreaterEqual(len(data["rows"]), 2)
self.assertEqual(data["rows"][0]["label"], "首仓")
self.assertEqual(data["market"]["display_symbol"], "ETH/USDT")
@patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve())
def test_trend_calculator_short_rejects_bad_bounds(self, _mock):
data, err = calc_trend_calculator(
direction="short",
capital_usdt=1000,
risk_percent=5,
leverage=5,
entry_price=100,
stop_loss=90,
add_upper=110,
take_profit=80,
dca_legs=3,
)
self.assertIsNone(data)
self.assertIsNotNone(err)
@patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve())
def test_roll_calculator_first_leg_auto(self, _mock):
data, err = calc_roll_calculator(
direction="long",
capital_usdt=1000,
risk_percent=5,
entry_price=100,
stop_loss=95,
take_profit=120,
add_legs=[],
legs_done=0,
)
self.assertIsNone(err)
self.assertIsNotNone(data)
assert data is not None
self.assertEqual(data["first_contracts"], 10.0)
self.assertEqual(len(data["rows"]), 1)
self.assertEqual(data["rows"][0]["loss_at_sl_u"], 50.0)
self.assertEqual(data["rows"][0]["profit_at_tp_u"], 200.0)
@patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve())
def test_roll_calculator_chain_two_legs(self, _mock):
data, err = calc_roll_calculator(
direction="long",
capital_usdt=1000,
risk_percent=5,
entry_price=100,
stop_loss=95,
take_profit=120,
add_legs=[
{"add_price": 105, "new_stop_loss": 98},
{"add_price": 108, "new_stop_loss": 101},
],
legs_done=0,
)
self.assertIsNone(err)
self.assertIsNotNone(data)
assert data is not None
self.assertEqual(len(data["rows"]), 3)
self.assertEqual(data["rows"][1]["label"], "滚仓1")
self.assertGreater(float(data["final_contracts"]), float(data["first_contracts"]))
@patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve())
def test_roll_calculator_rejects_too_many_legs(self, _mock):
data, err = calc_roll_calculator(
direction="long",
capital_usdt=1000,
risk_percent=5,
entry_price=100,
stop_loss=95,
take_profit=120,
add_legs=[
{"add_price": 105, "new_stop_loss": 98},
{"add_price": 108, "new_stop_loss": 101},
{"add_price": 110, "new_stop_loss": 103},
{"add_price": 112, "new_stop_loss": 105},
],
legs_done=0,
)
self.assertIsNone(data)
self.assertIsNotNone(err)
def test_initial_roll_qty(self):
qty, err = calc_initial_roll_qty("long", 100, 95, 50, 1.0)
self.assertIsNone(err)
self.assertEqual(qty, 10.0)
def test_initial_roll_qty_with_contract_size(self):
qty, err = calc_initial_roll_qty("long", 100, 95, 50, 0.1)
self.assertIsNone(err)
self.assertEqual(qty, 100.0)
def test_solve_add_with_contract_size(self):
q2, err = solve_add_amount_for_total_risk(
"long",
qty_existing=10.0,
entry_existing=100.0,
add_price=105.0,
new_stop=98.0,
risk_budget_usdt=50.0,
contract_size=1.0,
)
self.assertIsNone(err)
self.assertIsNotNone(q2)
assert q2 is not None
self.assertGreater(q2, 0)
if __name__ == "__main__":
unittest.main()
+113 -113
View File
@@ -1,113 +1,113 @@
"""hub_calculator_market_lib 合约解析。"""
import unittest
from unittest.mock import patch
from hub_calculator_market_lib import (
amount_decimals_from_exchange,
find_exchange,
get_calculator_market,
list_calculator_exchanges,
make_amount_precise_fn_from_market,
normalize_base_symbol,
resolve_usdt_perp_symbol,
)
class FakeExchange:
def __init__(self, markets: dict):
self.markets = markets
def market(self, symbol: str):
return self.markets[symbol]
def amount_to_precision(self, symbol: str, amount: float) -> str:
return f"{float(amount):.3f}"
class HubCalculatorMarketLibTests(unittest.TestCase):
def test_normalize_base_symbol(self):
self.assertEqual(normalize_base_symbol("eth"), "ETH")
self.assertEqual(normalize_base_symbol("ETH/USDT:USDT"), "ETH")
self.assertEqual(normalize_base_symbol("ETHUSDT"), "ETH")
def test_resolve_usdt_perp_symbol(self):
ex = FakeExchange(
{
"ETH/USDT:USDT": {
"base": "ETH",
"quote": "USDT",
"swap": True,
"active": True,
"contractSize": 1.0,
"limits": {"amount": {"min": 0.001}},
"precision": {"price": 2, "amount": 3},
}
}
)
sym, err = resolve_usdt_perp_symbol(ex, "ETH")
self.assertIsNone(err)
self.assertEqual(sym, "ETH/USDT:USDT")
def test_amount_decimals_from_exchange(self):
ex = FakeExchange({})
self.assertEqual(amount_decimals_from_exchange(ex, "ETH/USDT:USDT"), 3)
def test_make_amount_precise_fn_from_market(self):
fn = make_amount_precise_fn_from_market({"amount_decimals": 3, "min_amount": 0.001})
self.assertEqual(fn(1.23456), 1.234)
self.assertIsNone(fn(0.0001))
@patch.dict("os.environ", {"HUB_BRIDGE_TOKEN": "test-token"}, clear=False)
def test_hub_headers_use_x_hub_token(self):
from hub_calculator_market_lib import _hub_headers
self.assertEqual(_hub_headers(), {"X-Hub-Token": "test-token"})
@patch("hub_calculator_market_lib.fetch_instance_market_sync")
def test_get_calculator_market_from_instance(self, fetch_mock):
fetch_mock.return_value = {
"ok": True,
"base": "ETH",
"exchange_symbol": "ETH/USDT:USDT",
"display_symbol": "ETH/USDT",
"contract_size": 0.01,
"price_tick": 0.01,
"price_decimals": 2,
"amount_decimals": 2,
"min_amount": 0.01,
}
ex = {
"id": "0",
"key": "binance",
"name": "币安 · crypto_monitor_binance",
"enabled": True,
"flask_url": "http://127.0.0.1:5001",
}
data, err = get_calculator_market("0", "ETH", ex=ex)
self.assertIsNone(err)
self.assertIsNotNone(data)
assert data is not None
self.assertEqual(data["exchange_id"], "0")
self.assertEqual(data["exchange_name"], "币安 · crypto_monitor_binance")
self.assertEqual(data["contract_size"], 0.01)
@patch("hub_calculator_market_lib.enabled_exchanges")
def test_list_calculator_exchanges(self, enabled_mock):
enabled_mock.return_value = [
{"id": "0", "key": "binance", "name": "币安", "enabled": True},
]
rows = list_calculator_exchanges()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["id"], "0")
def test_find_exchange_by_id(self):
with patch(
"hub_calculator_market_lib.load_settings",
return_value={"exchanges": [{"id": "2", "key": "gate", "name": "Gate"}]},
):
self.assertEqual(find_exchange("2")["key"], "gate")
if __name__ == "__main__":
unittest.main()
"""hub_calculator_market_lib 合约解析。"""
import unittest
from unittest.mock import patch
from lib.hub.hub_calculator_market_lib import (
amount_decimals_from_exchange,
find_exchange,
get_calculator_market,
list_calculator_exchanges,
make_amount_precise_fn_from_market,
normalize_base_symbol,
resolve_usdt_perp_symbol,
)
class FakeExchange:
def __init__(self, markets: dict):
self.markets = markets
def market(self, symbol: str):
return self.markets[symbol]
def amount_to_precision(self, symbol: str, amount: float) -> str:
return f"{float(amount):.3f}"
class HubCalculatorMarketLibTests(unittest.TestCase):
def test_normalize_base_symbol(self):
self.assertEqual(normalize_base_symbol("eth"), "ETH")
self.assertEqual(normalize_base_symbol("ETH/USDT:USDT"), "ETH")
self.assertEqual(normalize_base_symbol("ETHUSDT"), "ETH")
def test_resolve_usdt_perp_symbol(self):
ex = FakeExchange(
{
"ETH/USDT:USDT": {
"base": "ETH",
"quote": "USDT",
"swap": True,
"active": True,
"contractSize": 1.0,
"limits": {"amount": {"min": 0.001}},
"precision": {"price": 2, "amount": 3},
}
}
)
sym, err = resolve_usdt_perp_symbol(ex, "ETH")
self.assertIsNone(err)
self.assertEqual(sym, "ETH/USDT:USDT")
def test_amount_decimals_from_exchange(self):
ex = FakeExchange({})
self.assertEqual(amount_decimals_from_exchange(ex, "ETH/USDT:USDT"), 3)
def test_make_amount_precise_fn_from_market(self):
fn = make_amount_precise_fn_from_market({"amount_decimals": 3, "min_amount": 0.001})
self.assertEqual(fn(1.23456), 1.234)
self.assertIsNone(fn(0.0001))
@patch.dict("os.environ", {"HUB_BRIDGE_TOKEN": "test-token"}, clear=False)
def test_hub_headers_use_x_hub_token(self):
from lib.hub.hub_calculator_market_lib import _hub_headers
self.assertEqual(_hub_headers(), {"X-Hub-Token": "test-token"})
@patch("lib.hub.hub_calculator_market_lib.fetch_instance_market_sync")
def test_get_calculator_market_from_instance(self, fetch_mock):
fetch_mock.return_value = {
"ok": True,
"base": "ETH",
"exchange_symbol": "ETH/USDT:USDT",
"display_symbol": "ETH/USDT",
"contract_size": 0.01,
"price_tick": 0.01,
"price_decimals": 2,
"amount_decimals": 2,
"min_amount": 0.01,
}
ex = {
"id": "0",
"key": "binance",
"name": "币安 · crypto_monitor_binance",
"enabled": True,
"flask_url": "http://127.0.0.1:5001",
}
data, err = get_calculator_market("0", "ETH", ex=ex)
self.assertIsNone(err)
self.assertIsNotNone(data)
assert data is not None
self.assertEqual(data["exchange_id"], "0")
self.assertEqual(data["exchange_name"], "币安 · crypto_monitor_binance")
self.assertEqual(data["contract_size"], 0.01)
@patch("lib.hub.hub_calculator_market_lib.enabled_exchanges")
def test_list_calculator_exchanges(self, enabled_mock):
enabled_mock.return_value = [
{"id": "0", "key": "binance", "name": "币安", "enabled": True},
]
rows = list_calculator_exchanges()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["id"], "0")
def test_find_exchange_by_id(self):
with patch(
"lib.hub.hub_calculator_market_lib.load_settings",
return_value={"exchanges": [{"id": "2", "key": "gate", "name": "Gate"}]},
):
self.assertEqual(find_exchange("2")["key"], "gate")
if __name__ == "__main__":
unittest.main()
+157 -157
View File
@@ -1,157 +1,157 @@
"""开仓计划库:CRUD 与胜率统计。"""
from __future__ import annotations
import tempfile
from pathlib import Path
from hub_entry_plan_lib import (
compute_entry_plan_stats,
create_entry_plan,
delete_entry_plan,
init_db,
list_entry_plans,
normalize_plan_symbol,
resolve_stats_date_bounds,
update_entry_plan,
)
def _base_payload(**overrides):
data = {
"plan_date": "2026-06-14",
"exchange_key": "binance",
"symbol": "BTC",
"plan_type": "trend",
"trend_timeframe": "4h",
"entry_timeframe": "15m",
"direction": "long",
"target_level": "70000",
"current_range": "68000-69000",
"entry_scheme": "breakout",
"note": "test",
}
data.update(overrides)
return data
def test_normalize_plan_symbol():
assert normalize_plan_symbol("btc") == "BTC/USDT"
assert normalize_plan_symbol("ETH/USDT") == "ETH/USDT"
def test_create_without_entry_scheme():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
payload = _base_payload()
del payload["entry_scheme"]
row = create_entry_plan(payload, db_path=db)
assert row["entry_scheme"] == ""
assert row["entry_scheme_label"] == "待填写"
def test_archive_requires_entry_scheme():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
payload = _base_payload()
del payload["entry_scheme"]
row = create_entry_plan(payload, db_path=db)
try:
update_entry_plan(int(row["id"]), {"result": "win"}, db_path=db)
assert False, "expected ValueError"
except ValueError as e:
assert "入场方案" in str(e)
updated = update_entry_plan(
int(row["id"]),
{"entry_scheme": "breakout", "result": "win"},
db_path=db,
)
assert updated["status"] == "archived"
def test_create_list_delete_active_plan():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
row = create_entry_plan(_base_payload(), db_path=db)
assert row["status"] == "active"
assert row["symbol"] == "BTC/USDT"
active = list_entry_plans(status="active", db_path=db)
assert len(active) == 1
assert delete_entry_plan(int(row["id"]), db_path=db) is True
assert list_entry_plans(status="active", db_path=db) == []
def test_archive_on_result():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
row = create_entry_plan(_base_payload(symbol="SOL"), db_path=db)
updated = update_entry_plan(
int(row["id"]),
{"result": "win", "pnl_amount": 12.5},
db_path=db,
)
assert updated["status"] == "archived"
assert updated["result"] == "win"
assert updated["pnl_amount"] == 12.5
assert list_entry_plans(status="active", db_path=db) == []
archived = list_entry_plans(status="archived", db_path=db)
assert len(archived) == 1
def test_archive_without_pnl_amount():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
row = create_entry_plan(_base_payload(symbol="DOGE"), db_path=db)
updated = update_entry_plan(int(row["id"]), {"result": "loss"}, db_path=db)
assert updated["status"] == "archived"
assert updated["pnl_amount"] is None
def test_cannot_delete_archived():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
row = create_entry_plan(_base_payload(), db_path=db)
update_entry_plan(int(row["id"]), {"result": "win"}, db_path=db)
try:
delete_entry_plan(int(row["id"]), db_path=db)
assert False, "expected ValueError"
except ValueError as e:
assert "仅进行中" in str(e)
def test_compute_stats_by_symbol():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
for sym, res in (("BTC", "win"), ("BTC", "loss"), ("ETH", "win")):
row = create_entry_plan(_base_payload(symbol=sym), db_path=db)
update_entry_plan(int(row["id"]), {"result": res}, db_path=db)
stats = compute_entry_plan_stats(dimension="symbol", period="all", db_path=db)
by_sym = {it["key"]: it for it in stats["items"]}
assert by_sym["BTC/USDT"]["win_count"] == 1
assert by_sym["BTC/USDT"]["loss_count"] == 1
assert by_sym["BTC/USDT"]["win_rate"] == 50.0
assert by_sym["ETH/USDT"]["win_count"] == 1
def test_stats_period_range_filter():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
row1 = create_entry_plan(_base_payload(plan_date="2026-06-01"), db_path=db)
row2 = create_entry_plan(_base_payload(plan_date="2026-06-20", symbol="ETH"), db_path=db)
update_entry_plan(int(row1["id"]), {"result": "win"}, db_path=db)
update_entry_plan(int(row2["id"]), {"result": "loss"}, db_path=db)
stats = compute_entry_plan_stats(
dimension="symbol",
period="range",
date_from="2026-06-01",
date_to="2026-06-10",
db_path=db,
)
assert len(stats["items"]) == 1
assert stats["items"][0]["key"] == "BTC/USDT"
def test_resolve_stats_date_bounds():
df, dt, label = resolve_stats_date_bounds(period="all")
assert df is None and dt is None
assert "全部" in label
"""开仓计划库:CRUD 与胜率统计。"""
from __future__ import annotations
import tempfile
from pathlib import Path
from lib.hub.hub_entry_plan_lib import (
compute_entry_plan_stats,
create_entry_plan,
delete_entry_plan,
init_db,
list_entry_plans,
normalize_plan_symbol,
resolve_stats_date_bounds,
update_entry_plan,
)
def _base_payload(**overrides):
data = {
"plan_date": "2026-06-14",
"exchange_key": "binance",
"symbol": "BTC",
"plan_type": "trend",
"trend_timeframe": "4h",
"entry_timeframe": "15m",
"direction": "long",
"target_level": "70000",
"current_range": "68000-69000",
"entry_scheme": "breakout",
"note": "test",
}
data.update(overrides)
return data
def test_normalize_plan_symbol():
assert normalize_plan_symbol("btc") == "BTC/USDT"
assert normalize_plan_symbol("ETH/USDT") == "ETH/USDT"
def test_create_without_entry_scheme():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
payload = _base_payload()
del payload["entry_scheme"]
row = create_entry_plan(payload, db_path=db)
assert row["entry_scheme"] == ""
assert row["entry_scheme_label"] == "待填写"
def test_archive_requires_entry_scheme():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
payload = _base_payload()
del payload["entry_scheme"]
row = create_entry_plan(payload, db_path=db)
try:
update_entry_plan(int(row["id"]), {"result": "win"}, db_path=db)
assert False, "expected ValueError"
except ValueError as e:
assert "入场方案" in str(e)
updated = update_entry_plan(
int(row["id"]),
{"entry_scheme": "breakout", "result": "win"},
db_path=db,
)
assert updated["status"] == "archived"
def test_create_list_delete_active_plan():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
row = create_entry_plan(_base_payload(), db_path=db)
assert row["status"] == "active"
assert row["symbol"] == "BTC/USDT"
active = list_entry_plans(status="active", db_path=db)
assert len(active) == 1
assert delete_entry_plan(int(row["id"]), db_path=db) is True
assert list_entry_plans(status="active", db_path=db) == []
def test_archive_on_result():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
row = create_entry_plan(_base_payload(symbol="SOL"), db_path=db)
updated = update_entry_plan(
int(row["id"]),
{"result": "win", "pnl_amount": 12.5},
db_path=db,
)
assert updated["status"] == "archived"
assert updated["result"] == "win"
assert updated["pnl_amount"] == 12.5
assert list_entry_plans(status="active", db_path=db) == []
archived = list_entry_plans(status="archived", db_path=db)
assert len(archived) == 1
def test_archive_without_pnl_amount():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
row = create_entry_plan(_base_payload(symbol="DOGE"), db_path=db)
updated = update_entry_plan(int(row["id"]), {"result": "loss"}, db_path=db)
assert updated["status"] == "archived"
assert updated["pnl_amount"] is None
def test_cannot_delete_archived():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
row = create_entry_plan(_base_payload(), db_path=db)
update_entry_plan(int(row["id"]), {"result": "win"}, db_path=db)
try:
delete_entry_plan(int(row["id"]), db_path=db)
assert False, "expected ValueError"
except ValueError as e:
assert "仅进行中" in str(e)
def test_compute_stats_by_symbol():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
for sym, res in (("BTC", "win"), ("BTC", "loss"), ("ETH", "win")):
row = create_entry_plan(_base_payload(symbol=sym), db_path=db)
update_entry_plan(int(row["id"]), {"result": res}, db_path=db)
stats = compute_entry_plan_stats(dimension="symbol", period="all", db_path=db)
by_sym = {it["key"]: it for it in stats["items"]}
assert by_sym["BTC/USDT"]["win_count"] == 1
assert by_sym["BTC/USDT"]["loss_count"] == 1
assert by_sym["BTC/USDT"]["win_rate"] == 50.0
assert by_sym["ETH/USDT"]["win_count"] == 1
def test_stats_period_range_filter():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "plans.db"
row1 = create_entry_plan(_base_payload(plan_date="2026-06-01"), db_path=db)
row2 = create_entry_plan(_base_payload(plan_date="2026-06-20", symbol="ETH"), db_path=db)
update_entry_plan(int(row1["id"]), {"result": "win"}, db_path=db)
update_entry_plan(int(row2["id"]), {"result": "loss"}, db_path=db)
stats = compute_entry_plan_stats(
dimension="symbol",
period="range",
date_from="2026-06-01",
date_to="2026-06-10",
db_path=db,
)
assert len(stats["items"]) == 1
assert stats["items"][0]["key"] == "BTC/USDT"
def test_resolve_stats_date_bounds():
df, dt, label = resolve_stats_date_bounds(period="all")
assert df is None and dt is None
assert "全部" in label
+113 -113
View File
@@ -1,113 +1,113 @@
"""hub_fund_history_lib:总资金、回撤与日快照。"""
from __future__ import annotations
from hub_fund_history_lib import (
account_total_usdt,
build_fund_overview,
compute_drawdown,
get_fund_history,
record_fund_snapshot,
)
def test_account_total_requires_both_sides():
assert account_total_usdt(10, 20) == 30.0
assert account_total_usdt(10, None) is None
assert account_total_usdt(None, 5) is None
def test_compute_drawdown():
dd = compute_drawdown([100, 120, 90, 110])
assert dd["peak_usdt"] == 120.0
assert dd["max_drawdown_u"] == 30.0
assert dd["max_drawdown_pct"] == 25.0
def test_build_fund_overview_skips_unmonitored(tmp_path, monkeypatch):
hist_path = tmp_path / "hub_fund_history.json"
monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_PATH", hist_path)
record_fund_snapshot(
"2026-06-01",
[
{
"key": "binance",
"name": "Binance",
"funding_usdt": 10,
"trading_usdt": 20,
"monitored": True,
}
],
keep_days=180,
)
record_fund_snapshot(
"2026-06-02",
[
{
"key": "binance",
"name": "Binance",
"funding_usdt": 12,
"trading_usdt": 18,
"monitored": True,
}
],
keep_days=180,
)
exchanges = [
{"id": "0", "key": "binance", "name": "Binance", "enabled": True},
{"id": "3", "key": "gate_bot", "name": "Gate Bot", "enabled": False},
]
board_rows = [
{
"key": "binance",
"name": "Binance",
"account_ok": True,
"funding_usdt": 15,
"trading_usdt": 25,
}
]
out = build_fund_overview(
exchanges,
board_rows=board_rows,
trading_day="2026-06-02",
keep_days=180,
)
assert out["totals"]["total_usdt"] == 40.0
assert out["totals"]["monitored_count"] == 1
assert len(out["accounts"]) == 1
assert all(a["monitored"] for a in out["accounts"])
assert out["totals"]["drawdown"]["max_drawdown_u"] == 0.0
def test_history_start_day_filters_older(tmp_path, monkeypatch):
hist_path = tmp_path / "hub_fund_history.json"
monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_PATH", hist_path)
monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_START_DAY", "2026-06-09")
record_fund_snapshot(
"2026-06-01",
[
{
"key": "binance",
"name": "Binance",
"funding_usdt": 1,
"trading_usdt": 1,
"monitored": True,
}
],
keep_days=180,
)
record_fund_snapshot(
"2026-06-09",
[
{
"key": "binance",
"name": "Binance",
"funding_usdt": 10,
"trading_usdt": 20,
"monitored": True,
}
],
keep_days=180,
)
hist = get_fund_history(anchor_day="2026-06-10", keep_days=180)
assert "2026-06-01" not in hist
assert "2026-06-09" in hist
"""hub_fund_history_lib:总资金、回撤与日快照。"""
from __future__ import annotations
from lib.hub.hub_fund_history_lib import (
account_total_usdt,
build_fund_overview,
compute_drawdown,
get_fund_history,
record_fund_snapshot,
)
def test_account_total_requires_both_sides():
assert account_total_usdt(10, 20) == 30.0
assert account_total_usdt(10, None) is None
assert account_total_usdt(None, 5) is None
def test_compute_drawdown():
dd = compute_drawdown([100, 120, 90, 110])
assert dd["peak_usdt"] == 120.0
assert dd["max_drawdown_u"] == 30.0
assert dd["max_drawdown_pct"] == 25.0
def test_build_fund_overview_skips_unmonitored(tmp_path, monkeypatch):
hist_path = tmp_path / "hub_fund_history.json"
monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_PATH", hist_path)
record_fund_snapshot(
"2026-06-01",
[
{
"key": "binance",
"name": "Binance",
"funding_usdt": 10,
"trading_usdt": 20,
"monitored": True,
}
],
keep_days=180,
)
record_fund_snapshot(
"2026-06-02",
[
{
"key": "binance",
"name": "Binance",
"funding_usdt": 12,
"trading_usdt": 18,
"monitored": True,
}
],
keep_days=180,
)
exchanges = [
{"id": "0", "key": "binance", "name": "Binance", "enabled": True},
{"id": "3", "key": "gate_bot", "name": "Gate Bot", "enabled": False},
]
board_rows = [
{
"key": "binance",
"name": "Binance",
"account_ok": True,
"funding_usdt": 15,
"trading_usdt": 25,
}
]
out = build_fund_overview(
exchanges,
board_rows=board_rows,
trading_day="2026-06-02",
keep_days=180,
)
assert out["totals"]["total_usdt"] == 40.0
assert out["totals"]["monitored_count"] == 1
assert len(out["accounts"]) == 1
assert all(a["monitored"] for a in out["accounts"])
assert out["totals"]["drawdown"]["max_drawdown_u"] == 0.0
def test_history_start_day_filters_older(tmp_path, monkeypatch):
hist_path = tmp_path / "hub_fund_history.json"
monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_PATH", hist_path)
monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_START_DAY", "2026-06-09")
record_fund_snapshot(
"2026-06-01",
[
{
"key": "binance",
"name": "Binance",
"funding_usdt": 1,
"trading_usdt": 1,
"monitored": True,
}
],
keep_days=180,
)
record_fund_snapshot(
"2026-06-09",
[
{
"key": "binance",
"name": "Binance",
"funding_usdt": 10,
"trading_usdt": 20,
"monitored": True,
}
],
keep_days=180,
)
hist = get_fund_history(anchor_day="2026-06-10", keep_days=180)
assert "2026-06-01" not in hist
assert "2026-06-09" in hist
+58 -58
View File
@@ -1,58 +1,58 @@
"""hub_host_status_lib 单元测试。"""
from __future__ import annotations
import sys
import unittest
from unittest.mock import MagicMock, patch
from hub_host_status_lib import _disk_path, _state, get_host_status
class HubHostStatusLibTest(unittest.TestCase):
def setUp(self):
_state["primed"] = False
_state["net_ts"] = 0.0
_state["net_sent"] = 0
_state["net_recv"] = 0
def test_disk_path_env_override(self):
with patch.dict("os.environ", {"HUB_HOST_DISK_PATH": "/data"}, clear=False):
self.assertEqual(_disk_path(), "/data")
def test_get_host_status_without_psutil(self):
import builtins
real_import = builtins.__import__
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == "psutil":
raise ImportError("no psutil")
return real_import(name, globals, locals, fromlist, level)
with patch("builtins.__import__", side_effect=fake_import):
out = get_host_status()
self.assertFalse(out.get("ok"))
self.assertIn("psutil", out.get("msg", ""))
def test_get_host_status_payload(self):
fake_vm = MagicMock(total=8_000_000_000, used=3_200_000_000, percent=40.0)
fake_du = MagicMock(total=100_000_000_000, used=50_000_000_000)
fake_net = MagicMock(bytes_sent=1_000_000, bytes_recv=2_000_000)
fake_psutil = MagicMock()
fake_psutil.cpu_percent.return_value = 12.5
fake_psutil.cpu_count.return_value = 4
fake_psutil.virtual_memory.return_value = fake_vm
fake_psutil.disk_usage.return_value = fake_du
fake_psutil.net_io_counters.return_value = fake_net
fake_psutil.boot_time.return_value = 1_700_000_000.0
with patch.dict(sys.modules, {"psutil": fake_psutil}):
out = get_host_status()
self.assertTrue(out.get("ok"))
self.assertEqual(out["cpu"]["percent"], 12.5)
self.assertEqual(out["memory"]["percent"], 40.0)
self.assertEqual(out["disk"]["percent"], 50.0)
self.assertIn("network", out)
if __name__ == "__main__":
unittest.main()
"""hub_host_status_lib 单元测试。"""
from __future__ import annotations
import sys
import unittest
from unittest.mock import MagicMock, patch
from lib.hub.hub_host_status_lib import _disk_path, _state, get_host_status
class HubHostStatusLibTest(unittest.TestCase):
def setUp(self):
_state["primed"] = False
_state["net_ts"] = 0.0
_state["net_sent"] = 0
_state["net_recv"] = 0
def test_disk_path_env_override(self):
with patch.dict("os.environ", {"HUB_HOST_DISK_PATH": "/data"}, clear=False):
self.assertEqual(_disk_path(), "/data")
def test_get_host_status_without_psutil(self):
import builtins
real_import = builtins.__import__
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == "psutil":
raise ImportError("no psutil")
return real_import(name, globals, locals, fromlist, level)
with patch("builtins.__import__", side_effect=fake_import):
out = get_host_status()
self.assertFalse(out.get("ok"))
self.assertIn("psutil", out.get("msg", ""))
def test_get_host_status_payload(self):
fake_vm = MagicMock(total=8_000_000_000, used=3_200_000_000, percent=40.0)
fake_du = MagicMock(total=100_000_000_000, used=50_000_000_000)
fake_net = MagicMock(bytes_sent=1_000_000, bytes_recv=2_000_000)
fake_psutil = MagicMock()
fake_psutil.cpu_percent.return_value = 12.5
fake_psutil.cpu_count.return_value = 4
fake_psutil.virtual_memory.return_value = fake_vm
fake_psutil.disk_usage.return_value = fake_du
fake_psutil.net_io_counters.return_value = fake_net
fake_psutil.boot_time.return_value = 1_700_000_000.0
with patch.dict(sys.modules, {"psutil": fake_psutil}):
out = get_host_status()
self.assertTrue(out.get("ok"))
self.assertEqual(out["cpu"]["percent"], 12.5)
self.assertEqual(out["memory"]["percent"], 40.0)
self.assertEqual(out["disk"]["percent"], 50.0)
self.assertIn("network", out)
if __name__ == "__main__":
unittest.main()
+466 -466
View File
@@ -1,466 +1,466 @@
"""中控 K 线库:分周期保留、聚合与分页读取。"""
from __future__ import annotations
import tempfile
import time
import unittest
from pathlib import Path
from hub_kline_store import (
HUB_KLINE_REMOTE_FETCH_CAP,
_since_ms_for_span,
clear_series_bars,
init_db,
load_bars_before,
load_bars_latest,
purge_retention,
purge_timeframe_by_days,
resolve_chart_bars,
retention_days,
trim_contiguous_tail,
upsert_bars,
)
from hub_ohlcv_lib import (
TIMEFRAME_MS,
bar_limit_for_timeframe,
chart_fetch_start_ms,
chart_initial_limit,
last_closed_bar_open_ms,
window_start_ms,
)
class TestHubKlineStore(unittest.TestCase):
def setUp(self):
self.tmp = tempfile.TemporaryDirectory()
self.db = Path(self.tmp.name) / "test_hub_kline.db"
def tearDown(self):
self.tmp.cleanup()
def test_bar_limits(self):
self.assertEqual(bar_limit_for_timeframe("5m"), 5000)
self.assertEqual(bar_limit_for_timeframe("1h"), 1000)
self.assertEqual(bar_limit_for_timeframe("1d"), 1000)
self.assertEqual(bar_limit_for_timeframe("1w"), 500)
self.assertEqual(chart_initial_limit("5m"), 2000)
self.assertEqual(chart_initial_limit("1h"), 1000)
self.assertEqual(chart_initial_limit("1d"), 500)
def test_chart_fetch_window_exceeds_retention(self):
now = int(time.time() * 1000)
need = bar_limit_for_timeframe("1d")
fetch_start = chart_fetch_start_ms("1d", need, now)
db_start = window_start_ms("1d", need, retention_days(), now)
self.assertLess(fetch_start, db_start)
def test_purge_retention_5m_one_year(self):
init_db(self.db)
old_ms = int(time.time() * 1000) - 400 * 86400000
upsert_bars(
"okx",
"BTC/USDT",
"5m",
[
{
"open_time_ms": old_ms,
"open": 1,
"high": 2,
"low": 0.5,
"close": 1.5,
"volume": 10,
}
],
self.db,
)
n = purge_timeframe_by_days("5m", 365, self.db)
self.assertGreaterEqual(n, 1)
rows = load_bars_latest("okx", "BTC/USDT", "5m", 10, self.db)
self.assertEqual(len(rows), 0)
def test_purge_retention_keeps_1d(self):
init_db(self.db)
old_ms = int(time.time() * 1000) - 400 * 86400000
upsert_bars(
"okx",
"BTC/USDT",
"1d",
[
{
"open_time_ms": old_ms,
"open": 1,
"high": 2,
"low": 0.5,
"close": 1.5,
"volume": 10,
}
],
self.db,
)
purge_retention(self.db)
rows = load_bars_latest("okx", "BTC/USDT", "1d", 10, self.db)
self.assertEqual(len(rows), 1)
def test_resolve_uses_cache_without_remote(self):
init_db(self.db)
now = int(time.time() * 1000)
tf = "5m"
period = TIMEFRAME_MS[tf]
last_closed = last_closed_bar_open_ms(tf, now)
bars = []
for i in range(400):
oms = last_closed - (399 - i) * period
bars.append(
{
"open_time_ms": oms,
"open": 100 + i,
"high": 101 + i,
"low": 99 + i,
"close": 100.5 + i,
"volume": 1000 + i,
}
)
upsert_bars("okx", "ETH/USDT", tf, bars, self.db)
def remote_fetch(**kwargs):
self.fail("不应请求交易所")
out = resolve_chart_bars(
"okx",
"ETH/USDT",
tf,
remote_fetch,
db_path=self.db,
limit=300,
)
self.assertTrue(out.get("ok"))
self.assertEqual(len(out.get("candles") or []), 300)
def test_resolve_15m_reads_native_bars(self):
init_db(self.db)
now = int(time.time() * 1000)
period = TIMEFRAME_MS["15m"]
last_closed = last_closed_bar_open_ms("15m", now)
bars = []
for i in range(12):
oms = last_closed - (11 - i) * period
bars.append(
{
"open_time_ms": oms,
"open": 1.0 + i,
"high": 2.0 + i,
"low": 0.5 + i,
"close": 1.5 + i,
"volume": 10.0,
}
)
upsert_bars("okx", "ETH/USDT", "15m", bars, self.db)
def remote_fetch(**kwargs):
self.fail("不应请求交易所")
out = resolve_chart_bars(
"okx",
"ETH/USDT",
"15m",
remote_fetch,
db_path=self.db,
limit=10,
)
self.assertTrue(out.get("ok"))
self.assertEqual(out.get("source"), "db")
self.assertEqual(out.get("storage_timeframe"), "15m")
self.assertGreaterEqual(len(out.get("candles") or []), 10)
def test_load_bars_before(self):
init_db(self.db)
period = TIMEFRAME_MS["1h"]
base = 1_700_000_000_000
bars = []
for i in range(5):
bars.append(
{
"open_time_ms": base + i * period,
"open": 1,
"high": 2,
"low": 0.5,
"close": 1.5,
"volume": 1,
}
)
upsert_bars("okx", "BTC/USDT", "1h", bars, self.db)
before = base + 3 * period
got = load_bars_before("okx", "BTC/USDT", "1h", before, 2, self.db)
self.assertEqual(len(got), 2)
self.assertEqual(got[-1]["open_time_ms"], base + 2 * period)
def test_trim_contiguous_tail_drops_orphan_prefix(self):
period = TIMEFRAME_MS["15m"]
base_old = 1_700_000_000_000
base_new = base_old + period * 500
bars = []
for i in range(3):
bars.append(
{
"open_time_ms": base_old + i * period,
"open": 1,
"high": 2,
"low": 0.5,
"close": 1.5,
"volume": 1,
}
)
for i in range(5):
bars.append(
{
"open_time_ms": base_new + i * period,
"open": 2,
"high": 3,
"low": 1.5,
"close": 2.5,
"volume": 2,
}
)
trimmed, split = trim_contiguous_tail(bars, period)
self.assertEqual(split, 3)
self.assertEqual(len(trimmed), 5)
self.assertEqual(trimmed[0]["open_time_ms"], base_new)
def test_resolve_drops_discontinuous_orphans(self):
init_db(self.db)
period = TIMEFRAME_MS["15m"]
now = int(time.time() * 1000)
old_ms = now - period * 800
upsert_bars(
"okx",
"ONDO/USDT",
"15m",
[
{
"open_time_ms": old_ms,
"open": 0.33,
"high": 0.34,
"low": 0.32,
"close": 0.335,
"volume": 100,
}
],
self.db,
)
recent = []
start = now - period * 20
for i in range(20):
recent.append(
{
"open_time_ms": start + i * period,
"open": 0.35,
"high": 0.36,
"low": 0.34,
"close": 0.355,
"volume": 50,
}
)
def remote_fetch(**kwargs):
return {"ok": True, "bars": recent, "price_tick": 0.0001}
out = resolve_chart_bars(
"okx",
"ONDO/USDT",
"15m",
remote_fetch,
db_path=self.db,
limit=50,
)
self.assertTrue(out.get("ok"))
candles = out.get("candles") or []
self.assertGreaterEqual(len(candles), 19)
if len(candles) >= 2:
for i in range(1, len(candles)):
gap = candles[i]["time"] - candles[i - 1]["time"]
self.assertLessEqual(gap, int(period / 1000 * 3.0))
def test_resolve_refetches_when_db_has_discontinuous_full_count(self):
init_db(self.db)
period = TIMEFRAME_MS["15m"]
now = int(time.time() * 1000)
old_start = now - period * 3000
recent_start = now - period * 25
old_bars = [
{
"open_time_ms": old_start + i * period,
"open": 62000,
"high": 62100,
"low": 61900,
"close": 62050,
"volume": 10,
}
for i in range(500)
]
recent = [
{
"open_time_ms": recent_start + i * period,
"open": 104000,
"high": 104100,
"low": 103900,
"close": 104050,
"volume": 20,
}
for i in range(30)
]
upsert_bars("binance", "BTC/USDT", "15m", old_bars, self.db)
upsert_bars("binance", "BTC/USDT", "15m", recent, self.db)
fetch_calls = []
def remote_fetch(**kwargs):
fetch_calls.append(dict(kwargs))
full = []
start = now - period * 120
for i in range(120):
full.append(
{
"open_time_ms": start + i * period,
"open": 104000 + i,
"high": 104100 + i,
"low": 103900 + i,
"close": 104050 + i,
"volume": 30,
}
)
return {"ok": True, "bars": full, "price_tick": 0.01}
out = resolve_chart_bars(
"binance",
"BTC/USDT",
"15m",
remote_fetch,
db_path=self.db,
limit=2000,
)
self.assertTrue(out.get("ok"))
self.assertGreater(len(fetch_calls), 0)
self.assertGreaterEqual(len(out.get("candles") or []), 100)
self.assertGreater(int(out.get("fetched") or 0), 0)
def test_clear_series_and_force_refetch(self):
init_db(self.db)
period = TIMEFRAME_MS["5m"]
now = int(time.time() * 1000)
stale = [
{
"open_time_ms": now - period * (i + 100),
"open": 1,
"high": 2,
"low": 0.5,
"close": 1.5,
"volume": 1,
}
for i in range(40)
]
upsert_bars("binance", "BTC/USDT", "5m", stale, self.db)
self.assertEqual(len(load_bars_latest("binance", "BTC/USDT", "5m", 100, self.db)), 40)
removed = clear_series_bars("binance", "BTC/USDT", "5m", self.db)
self.assertEqual(removed, 40)
self.assertEqual(len(load_bars_latest("binance", "BTC/USDT", "5m", 100, self.db)), 0)
fresh = [
{
"open_time_ms": now - period * (20 - i),
"open": 10,
"high": 11,
"low": 9,
"close": 10.5,
"volume": 2,
}
for i in range(20)
]
def remote_fetch(**kwargs):
return {"ok": True, "bars": fresh, "price_tick": 0.01}
out = resolve_chart_bars(
"binance",
"BTC/USDT",
"5m",
remote_fetch,
db_path=self.db,
force_refresh=True,
clear_db=True,
limit=50,
)
self.assertTrue(out.get("ok"))
self.assertGreaterEqual(int(out.get("cleared") or 0), 0)
self.assertGreater(int(out.get("fetched") or 0), 0)
self.assertGreaterEqual(len(out.get("candles") or []), 19)
def test_since_span_matches_fetch_limit_not_need(self):
period = TIMEFRAME_MS["15m"]
now_ms = 1_800_000_000_000
fetch_limit = HUB_KLINE_REMOTE_FETCH_CAP
since = _since_ms_for_span(
now_ms=now_ms,
period_ms=period,
span_bars=fetch_limit,
cutoff_ms=0,
)
self.assertEqual(since, now_ms - period * fetch_limit)
wrong_since = now_ms - period * chart_initial_limit("15m")
self.assertGreater(since, wrong_since)
def test_thin_series_tail_refresh_fetches_full_window(self):
init_db(self.db)
period = TIMEFRAME_MS["15m"]
now = int(time.time() * 1000)
last_closed = last_closed_bar_open_ms("15m", now)
bars = [
{
"open_time_ms": last_closed - period * (150 - i),
"open": 100000,
"high": 100100,
"low": 99900,
"close": 100050,
"volume": 1,
}
for i in range(150)
]
fetch_calls: list[dict] = []
def remote_fetch(**kwargs):
fetch_calls.append(dict(kwargs))
return {"ok": True, "bars": bars, "price_tick": 0.01}
out = resolve_chart_bars(
"binance",
"BTC/USDT",
"15m",
remote_fetch,
db_path=self.db,
tail_refresh=True,
)
self.assertTrue(out.get("ok"))
self.assertGreaterEqual(len(out.get("candles") or []), 100)
self.assertGreater(int(out.get("fetched") or 0), 0)
self.assertTrue(any(int(c.get("limit") or 0) > 30 for c in fetch_calls))
def test_resolve_before_ms_exhausted(self):
init_db(self.db)
def remote_fetch(**kwargs):
return {"ok": False, "msg": "no remote"}
out = resolve_chart_bars(
"okx",
"BTC/USDT",
"5m",
remote_fetch,
db_path=self.db,
limit=100,
before_ms=int(time.time() * 1000),
)
self.assertTrue(out.get("ok"))
self.assertEqual(out.get("candles"), [])
self.assertTrue(out.get("exhausted"))
if __name__ == "__main__":
unittest.main()
"""中控 K 线库:分周期保留、聚合与分页读取。"""
from __future__ import annotations
import tempfile
import time
import unittest
from pathlib import Path
from lib.hub.hub_kline_store import (
HUB_KLINE_REMOTE_FETCH_CAP,
_since_ms_for_span,
clear_series_bars,
init_db,
load_bars_before,
load_bars_latest,
purge_retention,
purge_timeframe_by_days,
resolve_chart_bars,
retention_days,
trim_contiguous_tail,
upsert_bars,
)
from lib.hub.hub_ohlcv_lib import (
TIMEFRAME_MS,
bar_limit_for_timeframe,
chart_fetch_start_ms,
chart_initial_limit,
last_closed_bar_open_ms,
window_start_ms,
)
class TestHubKlineStore(unittest.TestCase):
def setUp(self):
self.tmp = tempfile.TemporaryDirectory()
self.db = Path(self.tmp.name) / "test_hub_kline.db"
def tearDown(self):
self.tmp.cleanup()
def test_bar_limits(self):
self.assertEqual(bar_limit_for_timeframe("5m"), 5000)
self.assertEqual(bar_limit_for_timeframe("1h"), 1000)
self.assertEqual(bar_limit_for_timeframe("1d"), 1000)
self.assertEqual(bar_limit_for_timeframe("1w"), 500)
self.assertEqual(chart_initial_limit("5m"), 2000)
self.assertEqual(chart_initial_limit("1h"), 1000)
self.assertEqual(chart_initial_limit("1d"), 500)
def test_chart_fetch_window_exceeds_retention(self):
now = int(time.time() * 1000)
need = bar_limit_for_timeframe("1d")
fetch_start = chart_fetch_start_ms("1d", need, now)
db_start = window_start_ms("1d", need, retention_days(), now)
self.assertLess(fetch_start, db_start)
def test_purge_retention_5m_one_year(self):
init_db(self.db)
old_ms = int(time.time() * 1000) - 400 * 86400000
upsert_bars(
"okx",
"BTC/USDT",
"5m",
[
{
"open_time_ms": old_ms,
"open": 1,
"high": 2,
"low": 0.5,
"close": 1.5,
"volume": 10,
}
],
self.db,
)
n = purge_timeframe_by_days("5m", 365, self.db)
self.assertGreaterEqual(n, 1)
rows = load_bars_latest("okx", "BTC/USDT", "5m", 10, self.db)
self.assertEqual(len(rows), 0)
def test_purge_retention_keeps_1d(self):
init_db(self.db)
old_ms = int(time.time() * 1000) - 400 * 86400000
upsert_bars(
"okx",
"BTC/USDT",
"1d",
[
{
"open_time_ms": old_ms,
"open": 1,
"high": 2,
"low": 0.5,
"close": 1.5,
"volume": 10,
}
],
self.db,
)
purge_retention(self.db)
rows = load_bars_latest("okx", "BTC/USDT", "1d", 10, self.db)
self.assertEqual(len(rows), 1)
def test_resolve_uses_cache_without_remote(self):
init_db(self.db)
now = int(time.time() * 1000)
tf = "5m"
period = TIMEFRAME_MS[tf]
last_closed = last_closed_bar_open_ms(tf, now)
bars = []
for i in range(400):
oms = last_closed - (399 - i) * period
bars.append(
{
"open_time_ms": oms,
"open": 100 + i,
"high": 101 + i,
"low": 99 + i,
"close": 100.5 + i,
"volume": 1000 + i,
}
)
upsert_bars("okx", "ETH/USDT", tf, bars, self.db)
def remote_fetch(**kwargs):
self.fail("不应请求交易所")
out = resolve_chart_bars(
"okx",
"ETH/USDT",
tf,
remote_fetch,
db_path=self.db,
limit=300,
)
self.assertTrue(out.get("ok"))
self.assertEqual(len(out.get("candles") or []), 300)
def test_resolve_15m_reads_native_bars(self):
init_db(self.db)
now = int(time.time() * 1000)
period = TIMEFRAME_MS["15m"]
last_closed = last_closed_bar_open_ms("15m", now)
bars = []
for i in range(12):
oms = last_closed - (11 - i) * period
bars.append(
{
"open_time_ms": oms,
"open": 1.0 + i,
"high": 2.0 + i,
"low": 0.5 + i,
"close": 1.5 + i,
"volume": 10.0,
}
)
upsert_bars("okx", "ETH/USDT", "15m", bars, self.db)
def remote_fetch(**kwargs):
self.fail("不应请求交易所")
out = resolve_chart_bars(
"okx",
"ETH/USDT",
"15m",
remote_fetch,
db_path=self.db,
limit=10,
)
self.assertTrue(out.get("ok"))
self.assertEqual(out.get("source"), "db")
self.assertEqual(out.get("storage_timeframe"), "15m")
self.assertGreaterEqual(len(out.get("candles") or []), 10)
def test_load_bars_before(self):
init_db(self.db)
period = TIMEFRAME_MS["1h"]
base = 1_700_000_000_000
bars = []
for i in range(5):
bars.append(
{
"open_time_ms": base + i * period,
"open": 1,
"high": 2,
"low": 0.5,
"close": 1.5,
"volume": 1,
}
)
upsert_bars("okx", "BTC/USDT", "1h", bars, self.db)
before = base + 3 * period
got = load_bars_before("okx", "BTC/USDT", "1h", before, 2, self.db)
self.assertEqual(len(got), 2)
self.assertEqual(got[-1]["open_time_ms"], base + 2 * period)
def test_trim_contiguous_tail_drops_orphan_prefix(self):
period = TIMEFRAME_MS["15m"]
base_old = 1_700_000_000_000
base_new = base_old + period * 500
bars = []
for i in range(3):
bars.append(
{
"open_time_ms": base_old + i * period,
"open": 1,
"high": 2,
"low": 0.5,
"close": 1.5,
"volume": 1,
}
)
for i in range(5):
bars.append(
{
"open_time_ms": base_new + i * period,
"open": 2,
"high": 3,
"low": 1.5,
"close": 2.5,
"volume": 2,
}
)
trimmed, split = trim_contiguous_tail(bars, period)
self.assertEqual(split, 3)
self.assertEqual(len(trimmed), 5)
self.assertEqual(trimmed[0]["open_time_ms"], base_new)
def test_resolve_drops_discontinuous_orphans(self):
init_db(self.db)
period = TIMEFRAME_MS["15m"]
now = int(time.time() * 1000)
old_ms = now - period * 800
upsert_bars(
"okx",
"ONDO/USDT",
"15m",
[
{
"open_time_ms": old_ms,
"open": 0.33,
"high": 0.34,
"low": 0.32,
"close": 0.335,
"volume": 100,
}
],
self.db,
)
recent = []
start = now - period * 20
for i in range(20):
recent.append(
{
"open_time_ms": start + i * period,
"open": 0.35,
"high": 0.36,
"low": 0.34,
"close": 0.355,
"volume": 50,
}
)
def remote_fetch(**kwargs):
return {"ok": True, "bars": recent, "price_tick": 0.0001}
out = resolve_chart_bars(
"okx",
"ONDO/USDT",
"15m",
remote_fetch,
db_path=self.db,
limit=50,
)
self.assertTrue(out.get("ok"))
candles = out.get("candles") or []
self.assertGreaterEqual(len(candles), 19)
if len(candles) >= 2:
for i in range(1, len(candles)):
gap = candles[i]["time"] - candles[i - 1]["time"]
self.assertLessEqual(gap, int(period / 1000 * 3.0))
def test_resolve_refetches_when_db_has_discontinuous_full_count(self):
init_db(self.db)
period = TIMEFRAME_MS["15m"]
now = int(time.time() * 1000)
old_start = now - period * 3000
recent_start = now - period * 25
old_bars = [
{
"open_time_ms": old_start + i * period,
"open": 62000,
"high": 62100,
"low": 61900,
"close": 62050,
"volume": 10,
}
for i in range(500)
]
recent = [
{
"open_time_ms": recent_start + i * period,
"open": 104000,
"high": 104100,
"low": 103900,
"close": 104050,
"volume": 20,
}
for i in range(30)
]
upsert_bars("binance", "BTC/USDT", "15m", old_bars, self.db)
upsert_bars("binance", "BTC/USDT", "15m", recent, self.db)
fetch_calls = []
def remote_fetch(**kwargs):
fetch_calls.append(dict(kwargs))
full = []
start = now - period * 120
for i in range(120):
full.append(
{
"open_time_ms": start + i * period,
"open": 104000 + i,
"high": 104100 + i,
"low": 103900 + i,
"close": 104050 + i,
"volume": 30,
}
)
return {"ok": True, "bars": full, "price_tick": 0.01}
out = resolve_chart_bars(
"binance",
"BTC/USDT",
"15m",
remote_fetch,
db_path=self.db,
limit=2000,
)
self.assertTrue(out.get("ok"))
self.assertGreater(len(fetch_calls), 0)
self.assertGreaterEqual(len(out.get("candles") or []), 100)
self.assertGreater(int(out.get("fetched") or 0), 0)
def test_clear_series_and_force_refetch(self):
init_db(self.db)
period = TIMEFRAME_MS["5m"]
now = int(time.time() * 1000)
stale = [
{
"open_time_ms": now - period * (i + 100),
"open": 1,
"high": 2,
"low": 0.5,
"close": 1.5,
"volume": 1,
}
for i in range(40)
]
upsert_bars("binance", "BTC/USDT", "5m", stale, self.db)
self.assertEqual(len(load_bars_latest("binance", "BTC/USDT", "5m", 100, self.db)), 40)
removed = clear_series_bars("binance", "BTC/USDT", "5m", self.db)
self.assertEqual(removed, 40)
self.assertEqual(len(load_bars_latest("binance", "BTC/USDT", "5m", 100, self.db)), 0)
fresh = [
{
"open_time_ms": now - period * (20 - i),
"open": 10,
"high": 11,
"low": 9,
"close": 10.5,
"volume": 2,
}
for i in range(20)
]
def remote_fetch(**kwargs):
return {"ok": True, "bars": fresh, "price_tick": 0.01}
out = resolve_chart_bars(
"binance",
"BTC/USDT",
"5m",
remote_fetch,
db_path=self.db,
force_refresh=True,
clear_db=True,
limit=50,
)
self.assertTrue(out.get("ok"))
self.assertGreaterEqual(int(out.get("cleared") or 0), 0)
self.assertGreater(int(out.get("fetched") or 0), 0)
self.assertGreaterEqual(len(out.get("candles") or []), 19)
def test_since_span_matches_fetch_limit_not_need(self):
period = TIMEFRAME_MS["15m"]
now_ms = 1_800_000_000_000
fetch_limit = HUB_KLINE_REMOTE_FETCH_CAP
since = _since_ms_for_span(
now_ms=now_ms,
period_ms=period,
span_bars=fetch_limit,
cutoff_ms=0,
)
self.assertEqual(since, now_ms - period * fetch_limit)
wrong_since = now_ms - period * chart_initial_limit("15m")
self.assertGreater(since, wrong_since)
def test_thin_series_tail_refresh_fetches_full_window(self):
init_db(self.db)
period = TIMEFRAME_MS["15m"]
now = int(time.time() * 1000)
last_closed = last_closed_bar_open_ms("15m", now)
bars = [
{
"open_time_ms": last_closed - period * (150 - i),
"open": 100000,
"high": 100100,
"low": 99900,
"close": 100050,
"volume": 1,
}
for i in range(150)
]
fetch_calls: list[dict] = []
def remote_fetch(**kwargs):
fetch_calls.append(dict(kwargs))
return {"ok": True, "bars": bars, "price_tick": 0.01}
out = resolve_chart_bars(
"binance",
"BTC/USDT",
"15m",
remote_fetch,
db_path=self.db,
tail_refresh=True,
)
self.assertTrue(out.get("ok"))
self.assertGreaterEqual(len(out.get("candles") or []), 100)
self.assertGreater(int(out.get("fetched") or 0), 0)
self.assertTrue(any(int(c.get("limit") or 0) > 30 for c in fetch_calls))
def test_resolve_before_ms_exhausted(self):
init_db(self.db)
def remote_fetch(**kwargs):
return {"ok": False, "msg": "no remote"}
out = resolve_chart_bars(
"okx",
"BTC/USDT",
"5m",
remote_fetch,
db_path=self.db,
limit=100,
before_ms=int(time.time() * 1000),
)
self.assertTrue(out.get("ok"))
self.assertEqual(out.get("candles"), [])
self.assertTrue(out.get("exhausted"))
if __name__ == "__main__":
unittest.main()
+73 -73
View File
@@ -1,73 +1,73 @@
import os
import tempfile
import unittest
from pathlib import Path
from unittest import mock
from hub_macro_calendar_lib import (
build_banner_message,
create_event,
delete_event,
enrich_alert,
init_db,
list_active_alerts,
list_events,
update_event,
)
class HubMacroCalendarLibTests(unittest.TestCase):
def setUp(self):
self.tmp = tempfile.TemporaryDirectory()
self.db_path = Path(self.tmp.name) / "macro.db"
init_db(self.db_path)
def tearDown(self):
self.tmp.cleanup()
def test_create_and_list(self):
row = create_event("cpi", "2026-06-18 20:30", note="核心CPI", db_path=self.db_path)
self.assertEqual(row["event_type"], "cpi")
self.assertEqual(row["event_at"], "2026-06-18 20:30")
rows = list_events(now_ms=row["event_at_ms"] - 86400000, db_path=self.db_path)
self.assertEqual(len(rows), 1)
def test_duplicate_rejected(self):
create_event("fomc", "2026-07-01 02:00", db_path=self.db_path)
with self.assertRaises(ValueError):
create_event("fomc", "2026-07-01 02:00", db_path=self.db_path)
def test_active_window_and_messages(self):
row = create_event("employment", "2026-06-18 20:30", db_path=self.db_path)
t0 = int(row["event_at_ms"])
inside = enrich_alert(row, now_ms=t0 - 30 * 60 * 1000)
self.assertIsNotNone(inside)
self.assertEqual(inside["phase"], "imminent")
outside = enrich_alert(row, now_ms=t0 - 2 * 3600 * 1000)
self.assertIsNone(outside)
alerts = list_active_alerts(now_ms=t0 + 15 * 60 * 1000, db_path=self.db_path)
self.assertEqual(len(alerts), 1)
msg_pos = build_banner_message(alerts[0], has_positions=True)
msg_flat = build_banner_message(alerts[0], has_positions=False)
self.assertIn("注意仓位风险", msg_pos)
self.assertIn("建议等待", msg_flat)
def test_update_and_delete(self):
row = create_event("cpi", "2026-06-18 20:30", db_path=self.db_path)
updated = update_event(
row["id"],
event_at="2026-06-18 21:00",
note="修正时间",
db_path=self.db_path,
)
self.assertEqual(updated["event_at"], "2026-06-18 21:00")
self.assertTrue(delete_event(row["id"], db_path=self.db_path))
self.assertEqual(len(list_events(now_ms=updated["event_at_ms"], db_path=self.db_path)), 0)
def test_invalid_type(self):
with self.assertRaises(ValueError):
create_event("nfp", "2026-06-18 20:30", db_path=self.db_path)
if __name__ == "__main__":
unittest.main()
import os
import tempfile
import unittest
from pathlib import Path
from unittest import mock
from lib.hub.hub_macro_calendar_lib import (
build_banner_message,
create_event,
delete_event,
enrich_alert,
init_db,
list_active_alerts,
list_events,
update_event,
)
class HubMacroCalendarLibTests(unittest.TestCase):
def setUp(self):
self.tmp = tempfile.TemporaryDirectory()
self.db_path = Path(self.tmp.name) / "macro.db"
init_db(self.db_path)
def tearDown(self):
self.tmp.cleanup()
def test_create_and_list(self):
row = create_event("cpi", "2026-06-18 20:30", note="核心CPI", db_path=self.db_path)
self.assertEqual(row["event_type"], "cpi")
self.assertEqual(row["event_at"], "2026-06-18 20:30")
rows = list_events(now_ms=row["event_at_ms"] - 86400000, db_path=self.db_path)
self.assertEqual(len(rows), 1)
def test_duplicate_rejected(self):
create_event("fomc", "2026-07-01 02:00", db_path=self.db_path)
with self.assertRaises(ValueError):
create_event("fomc", "2026-07-01 02:00", db_path=self.db_path)
def test_active_window_and_messages(self):
row = create_event("employment", "2026-06-18 20:30", db_path=self.db_path)
t0 = int(row["event_at_ms"])
inside = enrich_alert(row, now_ms=t0 - 30 * 60 * 1000)
self.assertIsNotNone(inside)
self.assertEqual(inside["phase"], "imminent")
outside = enrich_alert(row, now_ms=t0 - 2 * 3600 * 1000)
self.assertIsNone(outside)
alerts = list_active_alerts(now_ms=t0 + 15 * 60 * 1000, db_path=self.db_path)
self.assertEqual(len(alerts), 1)
msg_pos = build_banner_message(alerts[0], has_positions=True)
msg_flat = build_banner_message(alerts[0], has_positions=False)
self.assertIn("注意仓位风险", msg_pos)
self.assertIn("建议等待", msg_flat)
def test_update_and_delete(self):
row = create_event("cpi", "2026-06-18 20:30", db_path=self.db_path)
updated = update_event(
row["id"],
event_at="2026-06-18 21:00",
note="修正时间",
db_path=self.db_path,
)
self.assertEqual(updated["event_at"], "2026-06-18 21:00")
self.assertTrue(delete_event(row["id"], db_path=self.db_path))
self.assertEqual(len(list_events(now_ms=updated["event_at_ms"], db_path=self.db_path)), 0)
def test_invalid_type(self):
with self.assertRaises(ValueError):
create_event("nfp", "2026-06-18 20:30", db_path=self.db_path)
if __name__ == "__main__":
unittest.main()
+39 -39
View File
@@ -1,39 +1,39 @@
"""hub /api/hub/monitorenrich 局部返回时须保留 keys。"""
from __future__ import annotations
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from hub_bridge import build_hub_monitor_payload # noqa: E402
class TestHubMonitorPayload(unittest.TestCase):
def test_partial_enrich_keeps_keys(self):
keys = [{"id": 7, "symbol": "BTC/USDT"}]
orders = [{"id": 1}]
trends = [{"id": 9, "symbol": "ETH/USDT"}]
rolls = []
def enrich_only_trends(**_kw):
return {"trends": [{"id": 9, "add_count": 2}]}
out = build_hub_monitor_payload(
keys=keys,
orders=orders,
trends=trends,
rolls=rolls,
enrich=enrich_only_trends,
)
self.assertTrue(out["ok"])
self.assertEqual(out["keys"], keys)
self.assertEqual(out["orders"], orders)
self.assertEqual(out["rolls"], rolls)
self.assertEqual(out["trends"][0]["add_count"], 2)
if __name__ == "__main__":
unittest.main()
"""hub /api/hub/monitorenrich 局部返回时须保留 keys。"""
from __future__ import annotations
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from lib.hub.hub_bridge import build_hub_monitor_payload # noqa: E402
class TestHubMonitorPayload(unittest.TestCase):
def test_partial_enrich_keeps_keys(self):
keys = [{"id": 7, "symbol": "BTC/USDT"}]
orders = [{"id": 1}]
trends = [{"id": 9, "symbol": "ETH/USDT"}]
rolls = []
def enrich_only_trends(**_kw):
return {"trends": [{"id": 9, "add_count": 2}]}
out = build_hub_monitor_payload(
keys=keys,
orders=orders,
trends=trends,
rolls=rolls,
enrich=enrich_only_trends,
)
self.assertTrue(out["ok"])
self.assertEqual(out["keys"], keys)
self.assertEqual(out["orders"], orders)
self.assertEqual(out["rolls"], rolls)
self.assertEqual(out["trends"][0]["add_count"], 2)
if __name__ == "__main__":
unittest.main()
+222 -223
View File
@@ -1,223 +1,222 @@
"""hub_ohlcv_lib:分页拉取(Gate 等单次不足 chunk 时仍继续)。"""
from __future__ import annotations
import unittest
from hub_ohlcv_lib import (
aggregate_ohlcv_bars,
bars_spacing_matches_timeframe,
fetch_ohlcv_for_hub,
normalize_price_tick,
)
class _FakeExchange:
def __init__(self, pages, *, timeframes=None):
self.pages = list(pages)
self.calls = []
self.markets = {}
self.timeframes = timeframes if timeframes is not None else {}
def fetch_ohlcv(self, symbol, timeframe=None, since=None, limit=None):
self.calls.append(
{"symbol": symbol, "since": since, "limit": limit, "timeframe": timeframe}
)
if not self.pages:
return []
page = self.pages.pop(0)
if since is None:
return page
return [b for b in page if b[0] >= since]
class TestHubOhlcvLib(unittest.TestCase):
def test_normalize_price_tick_snaps_powers_of_ten(self):
self.assertAlmostEqual(normalize_price_tick(0.00001), 0.00001)
self.assertAlmostEqual(normalize_price_tick(0.001), 0.001)
self.assertIsNone(normalize_price_tick(0))
def test_price_tick_from_decimal_precision(self):
class _Ex:
markets = {"BTC/USDT:USDT": {"precision": {"price": 2}, "info": {}, "limits": {}}}
def load_markets(self):
return self.markets
def market(self, sym):
return self.markets[sym]
def price_to_precision(self, sym, price):
return "12345.67"
tick = __import__("hub_ohlcv_lib", fromlist=["price_tick_from_market"]).price_tick_from_market(
_Ex(), "BTC/USDT:USDT"
)
self.assertAlmostEqual(tick, 0.01)
def test_price_tick_from_binance_price_filter(self):
class _Ex:
markets = {
"BTC/USDT:USDT": {
"precision": {"price": 2},
"info": {
"filters": [
{"filterType": "PRICE_FILTER", "tickSize": "0.10"},
{"filterType": "LOT_SIZE", "stepSize": "0.001"},
]
},
"limits": {},
}
}
def load_markets(self):
return self.markets
def market(self, sym):
return self.markets[sym]
def price_to_precision(self, sym, price):
return "12345.6"
from hub_ohlcv_lib import price_tick_from_market
tick = price_tick_from_market(_Ex(), "BTC/USDT:USDT")
self.assertAlmostEqual(tick, 0.10)
def test_price_tick_from_info_tick_size(self):
class _Ex:
markets = {
"INJ/USDT:USDT": {
"precision": {"price": 4},
"info": {"tickSize": "0.001"},
"limits": {},
}
}
def load_markets(self):
return self.markets
def market(self, sym):
return self.markets[sym]
def price_to_precision(self, sym, price):
return "7.123"
from hub_ohlcv_lib import price_tick_from_market
tick = price_tick_from_market(_Ex(), "INJ/USDT:USDT")
self.assertAlmostEqual(tick, 0.001)
def test_full_fetch_without_since_paginates_okx_style(self):
"""OKX 等无 since 单次约 300 根,须分页至 limit。"""
from hub_ohlcv_lib import TIMEFRAME_MS
step = TIMEFRAME_MS["1h"]
want = 1000
base = max(0, int(__import__("time").time() * 1000) - want * step)
pages = [
[[base + i * step, 1.0, 1.1, 0.9, 1.05, 100.0] for i in range(300)],
[[base + (300 + i) * step, 2.0, 2.1, 1.9, 2.05, 200.0] for i in range(300)],
[[base + (600 + i) * step, 3.0, 3.1, 2.9, 3.05, 300.0] for i in range(300)],
[[base + (900 + i) * step, 4.0, 4.1, 3.9, 4.05, 400.0] for i in range(100)],
]
ex = _FakeExchange(pages)
out = fetch_ohlcv_for_hub(
symbol="ONDO/USDT",
timeframe="1h",
since_ms=None,
limit=want,
normalize_symbol_input=lambda s: str(s).strip().upper(),
normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s,
ensure_markets_loaded=lambda: None,
exchange=ex,
)
self.assertTrue(out.get("ok"))
self.assertEqual(len(out.get("bars") or []), 1000)
self.assertGreaterEqual(len(ex.calls), 4)
self.assertAlmostEqual(out["bars"][-1]["close"], 4.05)
def test_pagination_continues_when_page_smaller_than_chunk(self):
"""Gate 等常返回 299 根/次,不应误判为已到末尾。"""
base = 1_700_000_000_000
step = 4 * 60 * 60 * 1000
page1 = [
[base + i * step, 1.0, 1.1, 0.9, 1.05, 100.0] for i in range(299)
]
page2 = [
[base + (299 + i) * step, 2.0, 2.1, 1.9, 2.05, 200.0] for i in range(299)
]
page3 = [
[base + (598 + i) * step, 3.0, 3.1, 2.9, 3.05, 300.0] for i in range(50)
]
ex = _FakeExchange([page1, page2, page3])
out = fetch_ohlcv_for_hub(
symbol="INJ/USDT",
timeframe="4h",
since_ms=base,
limit=600,
normalize_symbol_input=lambda s: str(s).strip().upper(),
normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s,
ensure_markets_loaded=lambda: None,
exchange=ex,
)
self.assertTrue(out.get("ok"))
self.assertEqual(len(out.get("bars") or []), 600)
self.assertGreaterEqual(len(ex.calls), 3)
self.assertAlmostEqual(out["bars"][-1]["close"], 3.05)
def test_pagination_stops_when_next_since_reaches_now(self):
"""Gate 等:分页 since 不得越过当前时间,避免 from>to。"""
from hub_ohlcv_lib import TIMEFRAME_MS
step = TIMEFRAME_MS["1d"]
now_ms = int(__import__("time").time() * 1000)
# 最后一页最后一根 K 的 next_since 将 >= now_ms,应停止不再请求
last_open = ((now_ms // step) - 2) * step
page = [
[last_open - step, 1.0, 1.1, 0.9, 1.0, 10.0],
[last_open, 1.1, 1.2, 1.0, 1.1, 11.0],
]
ex = _FakeExchange([page])
out = fetch_ohlcv_for_hub(
symbol="ONDO/USDT",
timeframe="1d",
since_ms=last_open - step * 5,
limit=10,
normalize_symbol_input=lambda s: str(s).strip().upper(),
normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s,
ensure_markets_loaded=lambda: None,
exchange=ex,
)
self.assertTrue(out.get("ok"))
self.assertGreaterEqual(len(out.get("bars") or []), 2)
self.assertLessEqual(len(ex.calls), 4)
def test_aggregate_ohlcv_bars_buckets(self):
from hub_ohlcv_lib import TIMEFRAME_MS
h1 = TIMEFRAME_MS["1h"]
h4 = TIMEFRAME_MS["4h"]
base = (1_700_000_000_000 // h4) * h4
src = [
{
"open_time_ms": base + i * h1,
"open": 1.0,
"high": 2.0,
"low": 0.5,
"close": 1.5,
"volume": 1.0,
}
for i in range(4)
]
out = aggregate_ohlcv_bars(src, "4h")
self.assertEqual(len(out), 1)
self.assertEqual(out[0]["volume"], 4.0)
self.assertEqual(out[0]["high"], 2.0)
self.assertEqual(out[0]["low"], 0.5)
if __name__ == "__main__":
unittest.main()
"""hub_ohlcv_lib:分页拉取(Gate 等单次不足 chunk 时仍继续)。"""
from __future__ import annotations
import unittest
from lib.hub.hub_ohlcv_lib import (
aggregate_ohlcv_bars,
bars_spacing_matches_timeframe,
fetch_ohlcv_for_hub,
normalize_price_tick,
price_tick_from_market,
)
class _FakeExchange:
def __init__(self, pages, *, timeframes=None):
self.pages = list(pages)
self.calls = []
self.markets = {}
self.timeframes = timeframes if timeframes is not None else {}
def fetch_ohlcv(self, symbol, timeframe=None, since=None, limit=None):
self.calls.append(
{"symbol": symbol, "since": since, "limit": limit, "timeframe": timeframe}
)
if not self.pages:
return []
page = self.pages.pop(0)
if since is None:
return page
return [b for b in page if b[0] >= since]
class TestHubOhlcvLib(unittest.TestCase):
def test_normalize_price_tick_snaps_powers_of_ten(self):
self.assertAlmostEqual(normalize_price_tick(0.00001), 0.00001)
self.assertAlmostEqual(normalize_price_tick(0.001), 0.001)
self.assertIsNone(normalize_price_tick(0))
def test_price_tick_from_decimal_precision(self):
class _Ex:
markets = {"BTC/USDT:USDT": {"precision": {"price": 2}, "info": {}, "limits": {}}}
def load_markets(self):
return self.markets
def market(self, sym):
return self.markets[sym]
def price_to_precision(self, sym, price):
return "12345.67"
tick = price_tick_from_market(_Ex(), "BTC/USDT:USDT")
self.assertAlmostEqual(tick, 0.01)
def test_price_tick_from_binance_price_filter(self):
class _Ex:
markets = {
"BTC/USDT:USDT": {
"precision": {"price": 2},
"info": {
"filters": [
{"filterType": "PRICE_FILTER", "tickSize": "0.10"},
{"filterType": "LOT_SIZE", "stepSize": "0.001"},
]
},
"limits": {},
}
}
def load_markets(self):
return self.markets
def market(self, sym):
return self.markets[sym]
def price_to_precision(self, sym, price):
return "12345.6"
from lib.hub.hub_ohlcv_lib import price_tick_from_market
tick = price_tick_from_market(_Ex(), "BTC/USDT:USDT")
self.assertAlmostEqual(tick, 0.10)
def test_price_tick_from_info_tick_size(self):
class _Ex:
markets = {
"INJ/USDT:USDT": {
"precision": {"price": 4},
"info": {"tickSize": "0.001"},
"limits": {},
}
}
def load_markets(self):
return self.markets
def market(self, sym):
return self.markets[sym]
def price_to_precision(self, sym, price):
return "7.123"
from lib.hub.hub_ohlcv_lib import price_tick_from_market
tick = price_tick_from_market(_Ex(), "INJ/USDT:USDT")
self.assertAlmostEqual(tick, 0.001)
def test_full_fetch_without_since_paginates_okx_style(self):
"""OKX 等无 since 单次约 300 根,须分页至 limit。"""
from lib.hub.hub_ohlcv_lib import TIMEFRAME_MS
step = TIMEFRAME_MS["1h"]
want = 1000
base = max(0, int(__import__("time").time() * 1000) - want * step)
pages = [
[[base + i * step, 1.0, 1.1, 0.9, 1.05, 100.0] for i in range(300)],
[[base + (300 + i) * step, 2.0, 2.1, 1.9, 2.05, 200.0] for i in range(300)],
[[base + (600 + i) * step, 3.0, 3.1, 2.9, 3.05, 300.0] for i in range(300)],
[[base + (900 + i) * step, 4.0, 4.1, 3.9, 4.05, 400.0] for i in range(100)],
]
ex = _FakeExchange(pages)
out = fetch_ohlcv_for_hub(
symbol="ONDO/USDT",
timeframe="1h",
since_ms=None,
limit=want,
normalize_symbol_input=lambda s: str(s).strip().upper(),
normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s,
ensure_markets_loaded=lambda: None,
exchange=ex,
)
self.assertTrue(out.get("ok"))
self.assertEqual(len(out.get("bars") or []), 1000)
self.assertGreaterEqual(len(ex.calls), 4)
self.assertAlmostEqual(out["bars"][-1]["close"], 4.05)
def test_pagination_continues_when_page_smaller_than_chunk(self):
"""Gate 等常返回 299 根/次,不应误判为已到末尾。"""
base = 1_700_000_000_000
step = 4 * 60 * 60 * 1000
page1 = [
[base + i * step, 1.0, 1.1, 0.9, 1.05, 100.0] for i in range(299)
]
page2 = [
[base + (299 + i) * step, 2.0, 2.1, 1.9, 2.05, 200.0] for i in range(299)
]
page3 = [
[base + (598 + i) * step, 3.0, 3.1, 2.9, 3.05, 300.0] for i in range(50)
]
ex = _FakeExchange([page1, page2, page3])
out = fetch_ohlcv_for_hub(
symbol="INJ/USDT",
timeframe="4h",
since_ms=base,
limit=600,
normalize_symbol_input=lambda s: str(s).strip().upper(),
normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s,
ensure_markets_loaded=lambda: None,
exchange=ex,
)
self.assertTrue(out.get("ok"))
self.assertEqual(len(out.get("bars") or []), 600)
self.assertGreaterEqual(len(ex.calls), 3)
self.assertAlmostEqual(out["bars"][-1]["close"], 3.05)
def test_pagination_stops_when_next_since_reaches_now(self):
"""Gate 等:分页 since 不得越过当前时间,避免 from>to。"""
from lib.hub.hub_ohlcv_lib import TIMEFRAME_MS
step = TIMEFRAME_MS["1d"]
now_ms = int(__import__("time").time() * 1000)
# 最后一页最后一根 K 的 next_since 将 >= now_ms,应停止不再请求
last_open = ((now_ms // step) - 2) * step
page = [
[last_open - step, 1.0, 1.1, 0.9, 1.0, 10.0],
[last_open, 1.1, 1.2, 1.0, 1.1, 11.0],
]
ex = _FakeExchange([page])
out = fetch_ohlcv_for_hub(
symbol="ONDO/USDT",
timeframe="1d",
since_ms=last_open - step * 5,
limit=10,
normalize_symbol_input=lambda s: str(s).strip().upper(),
normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s,
ensure_markets_loaded=lambda: None,
exchange=ex,
)
self.assertTrue(out.get("ok"))
self.assertGreaterEqual(len(out.get("bars") or []), 2)
self.assertLessEqual(len(ex.calls), 4)
def test_aggregate_ohlcv_bars_buckets(self):
from lib.hub.hub_ohlcv_lib import TIMEFRAME_MS
h1 = TIMEFRAME_MS["1h"]
h4 = TIMEFRAME_MS["4h"]
base = (1_700_000_000_000 // h4) * h4
src = [
{
"open_time_ms": base + i * h1,
"open": 1.0,
"high": 2.0,
"low": 0.5,
"close": 1.5,
"volume": 1.0,
}
for i in range(4)
]
out = aggregate_ohlcv_bars(src, "4h")
self.assertEqual(len(out), 1)
self.assertEqual(out[0]["volume"], 4.0)
self.assertEqual(out[0]["high"], 2.0)
self.assertEqual(out[0]["low"], 0.5)
if __name__ == "__main__":
unittest.main()
+6 -1
View File
@@ -5,7 +5,12 @@ import json
import sys
from pathlib import Path
import pytest
try:
import pytest
except ImportError: # pragma: no cover
import unittest
raise unittest.SkipTest("pytest not installed")
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
+348 -348
View File
@@ -1,348 +1,348 @@
"""币种档案库:5m 聚合与视窗计算。"""
from __future__ import annotations
import tempfile
from pathlib import Path
from hub_ohlcv_lib import aggregate_ohlcv_bars
from datetime import datetime, timezone
from zoneinfo import ZoneInfo
from hub_symbol_archive_lib import (
CHART_DISPLAY_TZ,
_compute_period_stats,
_fill_missing_bars,
init_db,
list_daily_trades,
load_symbol_trades,
ms_to_wall_clock_str,
parse_wall_clock_ms,
resolve_archive_chart,
trading_day_bounds_ms,
upsert_bars_5m,
upsert_trade_overlay,
list_symbol_rows,
upsert_trades_cache,
)
def _seed_5m_bars(
db: Path,
start_ms: int,
count: int,
step: int = 300_000,
*,
ex: str = "gate",
sym: str = "ONDO",
) -> None:
bars = []
price = 1.0
for i in range(count):
o = start_ms + i * step
price += 0.001
bars.append(
{
"open_time_ms": o,
"open": price,
"high": price + 0.002,
"low": price - 0.001,
"close": price + 0.001,
"volume": 100 + i,
}
)
upsert_bars_5m(ex, sym, bars, db_path=db)
def test_aggregate_15m_from_5m():
start = 1_700_000_000_000
bars = []
for i in range(6):
t = start + i * 300_000
bars.append(
{
"open_time_ms": t,
"open": 1.0,
"high": 1.1,
"low": 0.9,
"close": 1.05,
"volume": 10,
}
)
agg = aggregate_ohlcv_bars(bars, "15m")
assert len(agg) >= 1
assert agg[-1]["close"] == bars[-1]["close"]
assert agg[0]["open_time_ms"] <= agg[1]["open_time_ms"]
def test_resolve_archive_chart_15m():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
anchor = 1_700_000_000_000
_seed_5m_bars(db, anchor - 50 * 300_000, 120)
out = resolve_archive_chart(
"gate",
"ONDO",
"15m",
anchor_ms=anchor,
mode="hold",
bars=40,
db_path=db,
)
assert out["ok"] is True
assert out["timeframe"] == "15m"
assert len(out["candles"]) >= 10
def test_fill_missing_bars_continuity():
period = 300_000
start = (1_700_000_000_000 // period) * period
bars = [
{
"open_time_ms": start,
"open": 1.0,
"high": 1.1,
"low": 0.9,
"close": 1.05,
"volume": 10,
},
{
"open_time_ms": start + period * 2,
"open": 1.05,
"high": 1.15,
"low": 1.0,
"close": 1.1,
"volume": 8,
},
]
filled = _fill_missing_bars(bars, period, start, start + period * 2)
assert len(filled) >= 3
assert any(b.get("filled") for b in filled)
def test_resolve_archive_chart_history_range():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
open_ms = 1_700_000_000_000
close_ms = open_ms + 6 * 3600_000
_seed_5m_bars(db, open_ms - 20 * 300_000, 200, ex="gate", sym="BNB/USDT")
out = resolve_archive_chart(
"gate",
"BNB/USDT",
"15m",
opened_ms=open_ms,
closed_ms=close_ms,
mode="hold",
range_mode="history",
db_path=db,
)
assert out["ok"] is True
assert out.get("range_mode") == "history"
assert out.get("window_end_ms") <= close_ms + 4 * 3600_000
assert len(out["candles"]) >= 40
def test_sync_prunes_missing_trades():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
upsert_trades_cache(
"gate",
[
{"id": 1, "symbol": "BNB/USDT", "result": "止损", "pnl_amount": -1},
{"id": 2, "symbol": "BNB/USDT", "result": "止盈", "pnl_amount": 1},
],
db_path=db,
prune_missing=False,
)
stats = upsert_trades_cache(
"gate",
[{"id": 1, "symbol": "BNB/USDT", "result": "止损", "pnl_amount": -1}],
db_path=db,
prune_missing=True,
)
rows = load_symbol_trades("gate", "BNB/USDT", db_path=db)
assert len(rows) == 1
assert rows[0]["trade_id"] == 1
assert stats["removed"] == 1
def test_list_with_overlay_filters():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
upsert_trades_cache(
"gate",
[
{
"id": 1,
"symbol": "ONDO",
"direction": "long",
"result": "止盈",
"pnl_amount": 12.5,
"opened_at": "2026-01-01 10:00:00",
"closed_at": "2026-01-01 12:00:00",
"opened_at_ms": 1_700_000_000_000,
"closed_at_ms": 1_700_007_200_000,
},
{
"id": 2,
"symbol": "ONDO",
"direction": "short",
"result": "止损",
"pnl_amount": -3.2,
"opened_at": "2026-01-02 10:00:00",
"closed_at": "2026-01-02 11:00:00",
"opened_at_ms": 1_700_086_400_000,
"closed_at_ms": 1_700_090_000_000,
},
],
db_path=db,
)
upsert_trade_overlay("gate", 2, behavior_tag="sick", note="追高", db_path=db)
rows = list_symbol_rows(db_path=db)
assert len(rows) == 1
assert rows[0]["trade_count"] == 2
sick_only = list_symbol_rows(filter_sick=True, db_path=db)
assert len(sick_only) == 1
profit_only = list_symbol_rows(filter_profit=True, db_path=db)
assert len(profit_only) == 1
def test_parse_wall_clock_ms_uses_utc_plus_8():
ms = parse_wall_clock_ms("2026-06-07 20:30:00")
assert ms is not None
dt_utc = datetime.fromtimestamp(ms / 1000.0, tz=timezone.utc)
dt_bj = dt_utc.astimezone(CHART_DISPLAY_TZ)
assert dt_bj.strftime("%Y-%m-%d %H:%M:%S") == "2026-06-07 20:30:00"
assert ms_to_wall_clock_str(ms) == "2026-06-07 20:30:00"
assert parse_wall_clock_ms("2026-06-07 20:30") == ms
def test_parse_wall_clock_ms_accepts_epoch_strings():
ms = 1_700_000_000_000
assert parse_wall_clock_ms(str(ms)) == ms
assert parse_wall_clock_ms(str(ms // 1000)) == ms
def test_resolve_archive_chart_history_uses_trade_span_not_200_bars():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
opened = 1_700_000_000_000
closed = opened + 20 * 24 * 3600_000
_seed_5m_bars(db, opened - 35 * 24 * 3600_000, 40 * 24 * 12)
out = resolve_archive_chart(
"gate",
"ONDO",
"15m",
opened_ms=opened,
closed_ms=closed,
mode="hold",
bars=200,
range_mode="history",
db_path=db,
)
assert out["ok"] is True
assert out["range_mode"] == "history"
assert out["bar_count"] > 200
def test_upsert_forces_sync_exchange_key():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
upsert_trades_cache(
"gate_bot",
[
{
"id": 77,
"exchange_key": "gate",
"account_exchange_key": "gate",
"symbol": "ETH/USDT",
"result": "止损",
"pnl_amount": -1,
"opened_at_ms": 1_700_000_000_000,
"closed_at_ms": 1_700_007_200_000,
}
],
db_path=db,
)
rows = load_symbol_trades("gate_bot", "ETH/USDT", db_path=db)
assert len(rows) == 1
assert rows[0]["exchange_key"] == "gate_bot"
assert "account_exchange_key" not in rows[0]
def test_compute_period_stats_win_loss_metrics():
rows = [
{"exchange_key": "binance", "pnl_amount": 10.0, "behavior_tag": ""},
{"exchange_key": "binance", "pnl_amount": 4.0, "behavior_tag": ""},
{"exchange_key": "okx", "pnl_amount": -3.0, "behavior_tag": "sick"},
{"exchange_key": "okx", "pnl_amount": -6.0, "behavior_tag": ""},
]
st = _compute_period_stats(rows)
assert st["open_count"] == 4
assert st["win_count"] == 2
assert st["loss_count"] == 2
assert st["avg_win"] == 7.0
assert st["avg_loss"] == -4.5
assert st["max_win"] == 10.0
assert st["max_loss"] == -6.0
assert st["win_rate"] == 50.0
assert st["profit_loss_ratio"] == round(7.0 / 4.5, 2)
assert st["sick_count"] == 1
assert st["pnl_total"] == 5.0
assert st["pnl_ex_sick"] == 8.0
assert st["by_exchange"]["binance"]["win_count"] == 2
assert st["by_exchange"]["binance"]["win_rate"] == 100.0
assert st["by_exchange"]["binance"]["profit_loss_ratio"] is None
def test_list_daily_trades_search_filters_stats():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
day = "2023-11-15"
start_ms, _ = trading_day_bounds_ms(day)
btc_close = start_ms + 3_600_000
eth_close = start_ms + 7_200_000
upsert_trades_cache(
"gate",
[
{
"id": 1,
"symbol": "BTC/USDT",
"result": "止盈",
"pnl_amount": 5.0,
"opened_at_ms": start_ms,
"closed_at_ms": btc_close,
},
{
"id": 2,
"symbol": "ETH/USDT",
"result": "止损",
"pnl_amount": -2.0,
"opened_at_ms": btc_close,
"closed_at_ms": eth_close,
},
],
db_path=db,
)
payload = list_daily_trades(
period="range",
date_from=day,
date_to=day,
search="btc",
db_path=db,
)
assert len(payload["trades"]) == 1
assert payload["trades"][0]["symbol"] == "BTC/USDT"
st = payload["stats"]
assert st["open_count"] == 1
assert st["win_count"] == 1
assert st["loss_count"] == 0
assert st["max_win"] == 5.0
assert st["pnl_total"] == 5.0
"""币种档案库:5m 聚合与视窗计算。"""
from __future__ import annotations
import tempfile
from pathlib import Path
from lib.hub.hub_ohlcv_lib import aggregate_ohlcv_bars
from datetime import datetime, timezone
from zoneinfo import ZoneInfo
from lib.hub.hub_symbol_archive_lib import (
CHART_DISPLAY_TZ,
_compute_period_stats,
_fill_missing_bars,
init_db,
list_daily_trades,
load_symbol_trades,
ms_to_wall_clock_str,
parse_wall_clock_ms,
resolve_archive_chart,
trading_day_bounds_ms,
upsert_bars_5m,
upsert_trade_overlay,
list_symbol_rows,
upsert_trades_cache,
)
def _seed_5m_bars(
db: Path,
start_ms: int,
count: int,
step: int = 300_000,
*,
ex: str = "gate",
sym: str = "ONDO",
) -> None:
bars = []
price = 1.0
for i in range(count):
o = start_ms + i * step
price += 0.001
bars.append(
{
"open_time_ms": o,
"open": price,
"high": price + 0.002,
"low": price - 0.001,
"close": price + 0.001,
"volume": 100 + i,
}
)
upsert_bars_5m(ex, sym, bars, db_path=db)
def test_aggregate_15m_from_5m():
start = 1_700_000_000_000
bars = []
for i in range(6):
t = start + i * 300_000
bars.append(
{
"open_time_ms": t,
"open": 1.0,
"high": 1.1,
"low": 0.9,
"close": 1.05,
"volume": 10,
}
)
agg = aggregate_ohlcv_bars(bars, "15m")
assert len(agg) >= 1
assert agg[-1]["close"] == bars[-1]["close"]
assert agg[0]["open_time_ms"] <= agg[1]["open_time_ms"]
def test_resolve_archive_chart_15m():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
anchor = 1_700_000_000_000
_seed_5m_bars(db, anchor - 50 * 300_000, 120)
out = resolve_archive_chart(
"gate",
"ONDO",
"15m",
anchor_ms=anchor,
mode="hold",
bars=40,
db_path=db,
)
assert out["ok"] is True
assert out["timeframe"] == "15m"
assert len(out["candles"]) >= 10
def test_fill_missing_bars_continuity():
period = 300_000
start = (1_700_000_000_000 // period) * period
bars = [
{
"open_time_ms": start,
"open": 1.0,
"high": 1.1,
"low": 0.9,
"close": 1.05,
"volume": 10,
},
{
"open_time_ms": start + period * 2,
"open": 1.05,
"high": 1.15,
"low": 1.0,
"close": 1.1,
"volume": 8,
},
]
filled = _fill_missing_bars(bars, period, start, start + period * 2)
assert len(filled) >= 3
assert any(b.get("filled") for b in filled)
def test_resolve_archive_chart_history_range():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
open_ms = 1_700_000_000_000
close_ms = open_ms + 6 * 3600_000
_seed_5m_bars(db, open_ms - 20 * 300_000, 200, ex="gate", sym="BNB/USDT")
out = resolve_archive_chart(
"gate",
"BNB/USDT",
"15m",
opened_ms=open_ms,
closed_ms=close_ms,
mode="hold",
range_mode="history",
db_path=db,
)
assert out["ok"] is True
assert out.get("range_mode") == "history"
assert out.get("window_end_ms") <= close_ms + 4 * 3600_000
assert len(out["candles"]) >= 40
def test_sync_prunes_missing_trades():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
upsert_trades_cache(
"gate",
[
{"id": 1, "symbol": "BNB/USDT", "result": "止损", "pnl_amount": -1},
{"id": 2, "symbol": "BNB/USDT", "result": "止盈", "pnl_amount": 1},
],
db_path=db,
prune_missing=False,
)
stats = upsert_trades_cache(
"gate",
[{"id": 1, "symbol": "BNB/USDT", "result": "止损", "pnl_amount": -1}],
db_path=db,
prune_missing=True,
)
rows = load_symbol_trades("gate", "BNB/USDT", db_path=db)
assert len(rows) == 1
assert rows[0]["trade_id"] == 1
assert stats["removed"] == 1
def test_list_with_overlay_filters():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
upsert_trades_cache(
"gate",
[
{
"id": 1,
"symbol": "ONDO",
"direction": "long",
"result": "止盈",
"pnl_amount": 12.5,
"opened_at": "2026-01-01 10:00:00",
"closed_at": "2026-01-01 12:00:00",
"opened_at_ms": 1_700_000_000_000,
"closed_at_ms": 1_700_007_200_000,
},
{
"id": 2,
"symbol": "ONDO",
"direction": "short",
"result": "止损",
"pnl_amount": -3.2,
"opened_at": "2026-01-02 10:00:00",
"closed_at": "2026-01-02 11:00:00",
"opened_at_ms": 1_700_086_400_000,
"closed_at_ms": 1_700_090_000_000,
},
],
db_path=db,
)
upsert_trade_overlay("gate", 2, behavior_tag="sick", note="追高", db_path=db)
rows = list_symbol_rows(db_path=db)
assert len(rows) == 1
assert rows[0]["trade_count"] == 2
sick_only = list_symbol_rows(filter_sick=True, db_path=db)
assert len(sick_only) == 1
profit_only = list_symbol_rows(filter_profit=True, db_path=db)
assert len(profit_only) == 1
def test_parse_wall_clock_ms_uses_utc_plus_8():
ms = parse_wall_clock_ms("2026-06-07 20:30:00")
assert ms is not None
dt_utc = datetime.fromtimestamp(ms / 1000.0, tz=timezone.utc)
dt_bj = dt_utc.astimezone(CHART_DISPLAY_TZ)
assert dt_bj.strftime("%Y-%m-%d %H:%M:%S") == "2026-06-07 20:30:00"
assert ms_to_wall_clock_str(ms) == "2026-06-07 20:30:00"
assert parse_wall_clock_ms("2026-06-07 20:30") == ms
def test_parse_wall_clock_ms_accepts_epoch_strings():
ms = 1_700_000_000_000
assert parse_wall_clock_ms(str(ms)) == ms
assert parse_wall_clock_ms(str(ms // 1000)) == ms
def test_resolve_archive_chart_history_uses_trade_span_not_200_bars():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
opened = 1_700_000_000_000
closed = opened + 20 * 24 * 3600_000
_seed_5m_bars(db, opened - 35 * 24 * 3600_000, 40 * 24 * 12)
out = resolve_archive_chart(
"gate",
"ONDO",
"15m",
opened_ms=opened,
closed_ms=closed,
mode="hold",
bars=200,
range_mode="history",
db_path=db,
)
assert out["ok"] is True
assert out["range_mode"] == "history"
assert out["bar_count"] > 200
def test_upsert_forces_sync_exchange_key():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
upsert_trades_cache(
"gate_bot",
[
{
"id": 77,
"exchange_key": "gate",
"account_exchange_key": "gate",
"symbol": "ETH/USDT",
"result": "止损",
"pnl_amount": -1,
"opened_at_ms": 1_700_000_000_000,
"closed_at_ms": 1_700_007_200_000,
}
],
db_path=db,
)
rows = load_symbol_trades("gate_bot", "ETH/USDT", db_path=db)
assert len(rows) == 1
assert rows[0]["exchange_key"] == "gate_bot"
assert "account_exchange_key" not in rows[0]
def test_compute_period_stats_win_loss_metrics():
rows = [
{"exchange_key": "binance", "pnl_amount": 10.0, "behavior_tag": ""},
{"exchange_key": "binance", "pnl_amount": 4.0, "behavior_tag": ""},
{"exchange_key": "okx", "pnl_amount": -3.0, "behavior_tag": "sick"},
{"exchange_key": "okx", "pnl_amount": -6.0, "behavior_tag": ""},
]
st = _compute_period_stats(rows)
assert st["open_count"] == 4
assert st["win_count"] == 2
assert st["loss_count"] == 2
assert st["avg_win"] == 7.0
assert st["avg_loss"] == -4.5
assert st["max_win"] == 10.0
assert st["max_loss"] == -6.0
assert st["win_rate"] == 50.0
assert st["profit_loss_ratio"] == round(7.0 / 4.5, 2)
assert st["sick_count"] == 1
assert st["pnl_total"] == 5.0
assert st["pnl_ex_sick"] == 8.0
assert st["by_exchange"]["binance"]["win_count"] == 2
assert st["by_exchange"]["binance"]["win_rate"] == 100.0
assert st["by_exchange"]["binance"]["profit_loss_ratio"] is None
def test_list_daily_trades_search_filters_stats():
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
day = "2023-11-15"
start_ms, _ = trading_day_bounds_ms(day)
btc_close = start_ms + 3_600_000
eth_close = start_ms + 7_200_000
upsert_trades_cache(
"gate",
[
{
"id": 1,
"symbol": "BTC/USDT",
"result": "止盈",
"pnl_amount": 5.0,
"opened_at_ms": start_ms,
"closed_at_ms": btc_close,
},
{
"id": 2,
"symbol": "ETH/USDT",
"result": "止损",
"pnl_amount": -2.0,
"opened_at_ms": btc_close,
"closed_at_ms": eth_close,
},
],
db_path=db,
)
payload = list_daily_trades(
period="range",
date_from=day,
date_to=day,
search="btc",
db_path=db,
)
assert len(payload["trades"]) == 1
assert payload["trades"][0]["symbol"] == "BTC/USDT"
st = payload["stats"]
assert st["open_count"] == 1
assert st["win_count"] == 1
assert st["loss_count"] == 0
assert st["max_win"] == 5.0
assert st["pnl_total"] == 5.0
+102 -102
View File
@@ -1,102 +1,102 @@
"""档案交易:strategy_trade_snapshots 补全 gate_bot 漏记。"""
from __future__ import annotations
import sqlite3
import tempfile
from datetime import datetime, timedelta
from pathlib import Path
from hub_trades_lib import fetch_trades_for_archive
def _init_db(path: Path) -> sqlite3.Connection:
conn = sqlite3.connect(str(path))
conn.row_factory = sqlite3.Row
conn.execute(
"""
CREATE TABLE trade_records (
id INTEGER PRIMARY KEY,
symbol TEXT,
direction TEXT,
result TEXT,
pnl_amount REAL,
opened_at TEXT,
closed_at TEXT,
opened_at_ms INTEGER,
closed_at_ms INTEGER,
created_at TEXT,
trend_plan_id INTEGER
)
"""
)
conn.execute(
"""
CREATE TABLE strategy_trade_snapshots (
id INTEGER PRIMARY KEY,
strategy_type TEXT,
source_id INTEGER,
symbol TEXT,
direction TEXT,
result_label TEXT,
status_at_close TEXT,
opened_at TEXT,
closed_at TEXT,
pnl_amount REAL,
snapshot_json TEXT,
created_at TEXT
)
"""
)
return conn
def test_merge_snapshot_when_trade_record_missing():
with tempfile.TemporaryDirectory() as td:
conn = _init_db(Path(td) / "t.db")
closed = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S")
conn.execute(
"""
INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, direction,
result_label, opened_at, closed_at, pnl_amount, snapshot_json, created_at
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
""",
(7, "trend_pullback", 42, "ONDO/USDT", "long", "止损", closed, closed, -1.2, "{}", closed),
)
conn.commit()
trades = fetch_trades_for_archive(conn, days=30, limit=50)
conn.close()
assert len(trades) == 1
assert trades[0]["symbol"] == "ONDO/USDT"
assert trades[0]["id"] == -7
assert trades[0].get("from_snapshot") is True
def test_skip_snapshot_when_trade_record_exists():
with tempfile.TemporaryDirectory() as td:
conn = _init_db(Path(td) / "t.db")
closed = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S")
conn.execute(
"""
INSERT INTO trade_records (
id, symbol, direction, result, pnl_amount,
opened_at, closed_at, opened_at_ms, closed_at_ms, created_at, trend_plan_id
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
""",
(1, "ONDO/USDT", "long", "止损", -1.2, closed, closed, 1, 2, closed, 42),
)
conn.execute(
"""
INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, direction,
result_label, opened_at, closed_at, pnl_amount, snapshot_json, created_at
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
""",
(7, "trend_pullback", 42, "ONDO/USDT", "long", "止损", closed, closed, -1.2, "{}", closed),
)
conn.commit()
trades = fetch_trades_for_archive(conn, days=30, limit=50)
conn.close()
assert len(trades) == 1
assert trades[0]["id"] == 1
"""档案交易:strategy_trade_snapshots 补全 gate_bot 漏记。"""
from __future__ import annotations
import sqlite3
import tempfile
from datetime import datetime, timedelta
from pathlib import Path
from lib.hub.hub_trades_lib import fetch_trades_for_archive
def _init_db(path: Path) -> sqlite3.Connection:
conn = sqlite3.connect(str(path))
conn.row_factory = sqlite3.Row
conn.execute(
"""
CREATE TABLE trade_records (
id INTEGER PRIMARY KEY,
symbol TEXT,
direction TEXT,
result TEXT,
pnl_amount REAL,
opened_at TEXT,
closed_at TEXT,
opened_at_ms INTEGER,
closed_at_ms INTEGER,
created_at TEXT,
trend_plan_id INTEGER
)
"""
)
conn.execute(
"""
CREATE TABLE strategy_trade_snapshots (
id INTEGER PRIMARY KEY,
strategy_type TEXT,
source_id INTEGER,
symbol TEXT,
direction TEXT,
result_label TEXT,
status_at_close TEXT,
opened_at TEXT,
closed_at TEXT,
pnl_amount REAL,
snapshot_json TEXT,
created_at TEXT
)
"""
)
return conn
def test_merge_snapshot_when_trade_record_missing():
with tempfile.TemporaryDirectory() as td:
conn = _init_db(Path(td) / "t.db")
closed = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S")
conn.execute(
"""
INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, direction,
result_label, opened_at, closed_at, pnl_amount, snapshot_json, created_at
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
""",
(7, "trend_pullback", 42, "ONDO/USDT", "long", "止损", closed, closed, -1.2, "{}", closed),
)
conn.commit()
trades = fetch_trades_for_archive(conn, days=30, limit=50)
conn.close()
assert len(trades) == 1
assert trades[0]["symbol"] == "ONDO/USDT"
assert trades[0]["id"] == -7
assert trades[0].get("from_snapshot") is True
def test_skip_snapshot_when_trade_record_exists():
with tempfile.TemporaryDirectory() as td:
conn = _init_db(Path(td) / "t.db")
closed = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S")
conn.execute(
"""
INSERT INTO trade_records (
id, symbol, direction, result, pnl_amount,
opened_at, closed_at, opened_at_ms, closed_at_ms, created_at, trend_plan_id
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
""",
(1, "ONDO/USDT", "long", "止损", -1.2, closed, closed, 1, 2, closed, 42),
)
conn.execute(
"""
INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, direction,
result_label, opened_at, closed_at, pnl_amount, snapshot_json, created_at
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
""",
(7, "trend_pullback", 42, "ONDO/USDT", "long", "止损", closed, closed, -1.2, "{}", closed),
)
conn.commit()
trades = fetch_trades_for_archive(conn, days=30, limit=50)
conn.close()
assert len(trades) == 1
assert trades[0]["id"] == 1
+198 -198
View File
@@ -1,198 +1,198 @@
"""hub_trades_lib 单元测试。"""
from __future__ import annotations
import sqlite3
import unittest
from datetime import datetime
from hub_trades_lib import (
fetch_trades_for_trading_day,
summarize_trades,
trading_day_from_dt,
trading_day_window_bounds,
)
class HubTradesLibTest(unittest.TestCase):
def test_trading_day_reset(self):
dt = datetime(2026, 6, 6, 7, 30, 0)
self.assertEqual(trading_day_from_dt(dt, 8), "2026-06-05")
dt2 = datetime(2026, 6, 6, 8, 0, 0)
self.assertEqual(trading_day_from_dt(dt2, 8), "2026-06-06")
def test_trading_day_window_bounds(self):
start, end = trading_day_window_bounds("2026-06-06", 8)
self.assertEqual(start, "2026-06-06 08:00:00")
self.assertEqual(end, "2026-06-07 07:59:59")
def test_fetch_and_summarize(self):
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"""CREATE TABLE trade_records (
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
)"""
)
conn.execute(
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
(
"ONDO/USDT",
"short",
"止损",
None,
-0.5,
None,
None,
"2026-06-06 10:00:00",
None,
"2026-06-06 09:00:00",
None,
"2026-06-06 10:00:00",
"趋势回调",
None,
None,
"trend",
"",
None,
),
)
conn.commit()
rows = fetch_trades_for_trading_day(conn, "2026-06-06")
self.assertEqual(len(rows), 1)
stats = summarize_trades(rows)
self.assertEqual(stats["closed_count"], 1)
self.assertEqual(stats["loss_count"], 1)
self.assertAlmostEqual(stats["total_pnl_u"], -0.5)
conn.close()
def test_early_morning_belongs_prev_trading_day(self):
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"""CREATE TABLE trade_records (
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
)"""
)
conn.execute(
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
(
"BTC/USDT",
"long",
"止盈",
None,
1.2,
None,
None,
"2026-06-07 07:30:00",
None,
"2026-06-07 06:00:00",
None,
"2026-06-07 07:30:00",
"关键位",
None,
None,
"trend",
"",
None,
),
)
conn.commit()
self.assertEqual(len(fetch_trades_for_trading_day(conn, "2026-06-07")), 0)
self.assertEqual(len(fetch_trades_for_trading_day(conn, "2026-06-06")), 1)
conn.close()
def test_reviewed_fields_preferred(self):
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"""CREATE TABLE trade_records (
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
)"""
)
conn.execute(
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
(
"ETH/USDT",
"long",
"止损",
"止盈",
-0.5,
2.0,
None,
"2026-06-06 09:00:00",
"2026-06-06 11:00:00",
"2026-06-06 08:00:00",
None,
"2026-06-06 11:00:00",
"趋势回调",
None,
None,
"trend",
"",
"2026-06-06 12:00:00",
),
)
conn.commit()
rows = fetch_trades_for_trading_day(conn, "2026-06-06")
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["result"], "止盈")
self.assertAlmostEqual(rows[0]["pnl_amount"], 2.0)
self.assertTrue(rows[0]["reviewed"])
conn.close()
def test_time_close_result_included(self):
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"""CREATE TABLE trade_records (
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
)"""
)
conn.execute(
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
(
"BTC/USDT",
"long",
"时间平仓",
None,
1.2,
None,
None,
"2026-06-06 12:00:00",
None,
"2026-06-06 08:00:00",
None,
"2026-06-06 12:00:00",
"趋势回调",
None,
None,
"trend",
"",
None,
),
)
conn.commit()
rows = fetch_trades_for_trading_day(conn, "2026-06-06")
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["result"], "时间平仓")
conn.close()
if __name__ == "__main__":
unittest.main()
"""hub_trades_lib 单元测试。"""
from __future__ import annotations
import sqlite3
import unittest
from datetime import datetime
from lib.hub.hub_trades_lib import (
fetch_trades_for_trading_day,
summarize_trades,
trading_day_from_dt,
trading_day_window_bounds,
)
class HubTradesLibTest(unittest.TestCase):
def test_trading_day_reset(self):
dt = datetime(2026, 6, 6, 7, 30, 0)
self.assertEqual(trading_day_from_dt(dt, 8), "2026-06-05")
dt2 = datetime(2026, 6, 6, 8, 0, 0)
self.assertEqual(trading_day_from_dt(dt2, 8), "2026-06-06")
def test_trading_day_window_bounds(self):
start, end = trading_day_window_bounds("2026-06-06", 8)
self.assertEqual(start, "2026-06-06 08:00:00")
self.assertEqual(end, "2026-06-07 07:59:59")
def test_fetch_and_summarize(self):
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"""CREATE TABLE trade_records (
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
)"""
)
conn.execute(
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
(
"ONDO/USDT",
"short",
"止损",
None,
-0.5,
None,
None,
"2026-06-06 10:00:00",
None,
"2026-06-06 09:00:00",
None,
"2026-06-06 10:00:00",
"趋势回调",
None,
None,
"trend",
"",
None,
),
)
conn.commit()
rows = fetch_trades_for_trading_day(conn, "2026-06-06")
self.assertEqual(len(rows), 1)
stats = summarize_trades(rows)
self.assertEqual(stats["closed_count"], 1)
self.assertEqual(stats["loss_count"], 1)
self.assertAlmostEqual(stats["total_pnl_u"], -0.5)
conn.close()
def test_early_morning_belongs_prev_trading_day(self):
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"""CREATE TABLE trade_records (
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
)"""
)
conn.execute(
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
(
"BTC/USDT",
"long",
"止盈",
None,
1.2,
None,
None,
"2026-06-07 07:30:00",
None,
"2026-06-07 06:00:00",
None,
"2026-06-07 07:30:00",
"关键位",
None,
None,
"trend",
"",
None,
),
)
conn.commit()
self.assertEqual(len(fetch_trades_for_trading_day(conn, "2026-06-07")), 0)
self.assertEqual(len(fetch_trades_for_trading_day(conn, "2026-06-06")), 1)
conn.close()
def test_reviewed_fields_preferred(self):
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"""CREATE TABLE trade_records (
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
)"""
)
conn.execute(
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
(
"ETH/USDT",
"long",
"止损",
"止盈",
-0.5,
2.0,
None,
"2026-06-06 09:00:00",
"2026-06-06 11:00:00",
"2026-06-06 08:00:00",
None,
"2026-06-06 11:00:00",
"趋势回调",
None,
None,
"trend",
"",
"2026-06-06 12:00:00",
),
)
conn.commit()
rows = fetch_trades_for_trading_day(conn, "2026-06-06")
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["result"], "止盈")
self.assertAlmostEqual(rows[0]["pnl_amount"], 2.0)
self.assertTrue(rows[0]["reviewed"])
conn.close()
def test_time_close_result_included(self):
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"""CREATE TABLE trade_records (
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
)"""
)
conn.execute(
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
(
"BTC/USDT",
"long",
"时间平仓",
None,
1.2,
None,
None,
"2026-06-06 12:00:00",
None,
"2026-06-06 08:00:00",
None,
"2026-06-06 12:00:00",
"趋势回调",
None,
None,
"trend",
"",
None,
),
)
conn.commit()
rows = fetch_trades_for_trading_day(conn, "2026-06-06")
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["result"], "时间平仓")
conn.close()
if __name__ == "__main__":
unittest.main()
+115 -115
View File
@@ -1,115 +1,115 @@
"""档案交易:复盘字段优先(开仓类型、持仓时长、开平仓时间)。"""
from __future__ import annotations
import tempfile
import unittest
from datetime import datetime, timedelta
from pathlib import Path
from hub_symbol_archive_lib import init_db, load_symbol_trades, upsert_trades_cache
from hub_trades_lib import (
_normalize_archive_trade_row,
display_entry_type_label,
effective_entry_type,
effective_hold_minutes,
)
class TestHubTradesReviewFields(unittest.TestCase):
def test_display_entry_type_for_manual_monitor_review(self):
d = {
"monitor_type": "下单监控",
"entry_reason": "",
"reviewed_entry_reason": "突破回踩",
"reviewed_at": "2026-06-08 10:00:00",
}
self.assertEqual(display_entry_type_label(d), "突破回踩")
def test_effective_entry_type_prefers_reviewed(self):
d = {
"entry_reason": "突破回踩",
"reviewed_entry_reason": "趋势回调",
"monitor_type": "下单监控",
}
self.assertEqual(effective_entry_type(d), "趋势回调")
def test_effective_hold_minutes_prefers_reviewed(self):
d = {
"hold_minutes": 30,
"reviewed_hold_minutes": 95,
"opened_at_ms": 1_700_000_000_000,
"closed_at_ms": 1_700_001_800_000,
}
self.assertEqual(effective_hold_minutes(d), 95)
def test_normalize_archive_trade_row_review_fields(self):
closed = (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d %H:%M:%S")
opened = (datetime.now() - timedelta(days=2, hours=2)).strftime("%Y-%m-%d %H:%M:%S")
row = _normalize_archive_trade_row(
{
"id": 9,
"symbol": "ONDO/USDT",
"direction": "short",
"result": "止损",
"reviewed_result": "手动平仓",
"pnl_amount": -2.5,
"reviewed_pnl_amount": -2.58,
"opened_at": opened,
"reviewed_opened_at": "2026-06-07 14:30:00",
"closed_at": closed,
"reviewed_closed_at": "2026-06-08 08:44:21",
"opened_at_ms": 1_700_000_000_000,
"closed_at_ms": 1_700_007_200_000,
"entry_reason": "突破回踩",
"reviewed_entry_reason": "趋势回调",
"hold_minutes": 30,
"reviewed_hold_minutes": 1080,
"monitor_type": "趋势回调",
"reviewed_at": closed,
},
exchange_key="gate",
)
self.assertIsNotNone(row)
assert row is not None
self.assertEqual(row["entry_type"], "趋势回调")
self.assertEqual(row["hold_minutes"], 1080)
self.assertEqual(row["opened_at"], "2026-06-07 14:30:00")
self.assertEqual(row["closed_at"], "2026-06-08 08:44:21")
self.assertTrue(row["reviewed"])
def test_archive_cache_enriches_review_display_fields(self):
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
upsert_trades_cache(
"gate",
[
{
"id": 3,
"symbol": "ONDO/USDT",
"direction": "short",
"result": "手动平仓",
"pnl_amount": -2.58,
"opened_at": "2026-06-07 14:30:00",
"closed_at": "2026-06-08 08:44:21",
"opened_at_ms": 1_781_000_000_000,
"closed_at_ms": 1_781_065_000_000,
"entry_type": "趋势回调",
"hold_minutes": 1080,
"hold_minutes_text": "18小时0分钟",
"reviewed": True,
}
],
db_path=db,
)
rows = load_symbol_trades("gate", "ONDO/USDT", db_path=db)
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["entry_type"], "趋势回调")
self.assertEqual(rows[0]["hold_minutes"], 1080)
self.assertTrue(rows[0]["opened_at"].startswith("2026-06-07"))
self.assertTrue(rows[0]["closed_at"].startswith("2026-06-08"))
if __name__ == "__main__":
unittest.main()
"""档案交易:复盘字段优先(开仓类型、持仓时长、开平仓时间)。"""
from __future__ import annotations
import tempfile
import unittest
from datetime import datetime, timedelta
from pathlib import Path
from lib.hub.hub_symbol_archive_lib import init_db, load_symbol_trades, upsert_trades_cache
from lib.hub.hub_trades_lib import (
_normalize_archive_trade_row,
display_entry_type_label,
effective_entry_type,
effective_hold_minutes,
)
class TestHubTradesReviewFields(unittest.TestCase):
def test_display_entry_type_for_manual_monitor_review(self):
d = {
"monitor_type": "下单监控",
"entry_reason": "",
"reviewed_entry_reason": "突破回踩",
"reviewed_at": "2026-06-08 10:00:00",
}
self.assertEqual(display_entry_type_label(d), "突破回踩")
def test_effective_entry_type_prefers_reviewed(self):
d = {
"entry_reason": "突破回踩",
"reviewed_entry_reason": "趋势回调",
"monitor_type": "下单监控",
}
self.assertEqual(effective_entry_type(d), "趋势回调")
def test_effective_hold_minutes_prefers_reviewed(self):
d = {
"hold_minutes": 30,
"reviewed_hold_minutes": 95,
"opened_at_ms": 1_700_000_000_000,
"closed_at_ms": 1_700_001_800_000,
}
self.assertEqual(effective_hold_minutes(d), 95)
def test_normalize_archive_trade_row_review_fields(self):
closed = (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d %H:%M:%S")
opened = (datetime.now() - timedelta(days=2, hours=2)).strftime("%Y-%m-%d %H:%M:%S")
row = _normalize_archive_trade_row(
{
"id": 9,
"symbol": "ONDO/USDT",
"direction": "short",
"result": "止损",
"reviewed_result": "手动平仓",
"pnl_amount": -2.5,
"reviewed_pnl_amount": -2.58,
"opened_at": opened,
"reviewed_opened_at": "2026-06-07 14:30:00",
"closed_at": closed,
"reviewed_closed_at": "2026-06-08 08:44:21",
"opened_at_ms": 1_700_000_000_000,
"closed_at_ms": 1_700_007_200_000,
"entry_reason": "突破回踩",
"reviewed_entry_reason": "趋势回调",
"hold_minutes": 30,
"reviewed_hold_minutes": 1080,
"monitor_type": "趋势回调",
"reviewed_at": closed,
},
exchange_key="gate",
)
self.assertIsNotNone(row)
assert row is not None
self.assertEqual(row["entry_type"], "趋势回调")
self.assertEqual(row["hold_minutes"], 1080)
self.assertEqual(row["opened_at"], "2026-06-07 14:30:00")
self.assertEqual(row["closed_at"], "2026-06-08 08:44:21")
self.assertTrue(row["reviewed"])
def test_archive_cache_enriches_review_display_fields(self):
with tempfile.TemporaryDirectory() as td:
db = Path(td) / "archive.db"
init_db(db)
upsert_trades_cache(
"gate",
[
{
"id": 3,
"symbol": "ONDO/USDT",
"direction": "short",
"result": "手动平仓",
"pnl_amount": -2.58,
"opened_at": "2026-06-07 14:30:00",
"closed_at": "2026-06-08 08:44:21",
"opened_at_ms": 1_781_000_000_000,
"closed_at_ms": 1_781_065_000_000,
"entry_type": "趋势回调",
"hold_minutes": 1080,
"hold_minutes_text": "18小时0分钟",
"reviewed": True,
}
],
db_path=db,
)
rows = load_symbol_trades("gate", "ONDO/USDT", db_path=db)
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["entry_type"], "趋势回调")
self.assertEqual(rows[0]["hold_minutes"], 1080)
self.assertTrue(rows[0]["opened_at"].startswith("2026-06-07"))
self.assertTrue(rows[0]["closed_at"].startswith("2026-06-08"))
if __name__ == "__main__":
unittest.main()
+184 -184
View File
@@ -1,184 +1,184 @@
from datetime import datetime
from unittest.mock import MagicMock
from hub_volume_rank_lib import (
CACHE_VERSION,
LIQUIDITY_RANK_CACHE_VERSION,
TOP_N_DEFAULT,
_exchange_rank_row_stale,
_okx_turnover_usdt,
_scores_from_binance,
_scores_from_gate,
build_usdt_swap_volume_ranks,
cache_needs_refresh,
format_volume_quote,
merge_exchange_rank,
rank_date_label,
resolve_daily_volume_rank,
)
def test_rank_date_label_after_reset():
# 2026-06-08 09:00 北京时间 → 昨日交易日 2026-06-07
dt = datetime(2026, 6, 8, 9, 0, 0)
assert rank_date_label(now=dt, reset_hour=8) == "2026-06-07"
def test_rank_date_label_before_reset():
# 2026-06-08 07:00 → 当前交易日仍算 2026-06-07,昨日为 2026-06-06
dt = datetime(2026, 6, 8, 7, 0, 0)
assert rank_date_label(now=dt, reset_hour=8) == "2026-06-06"
def test_format_volume_quote():
assert format_volume_quote(1_500_000_000) == "1.50B"
assert format_volume_quote(2_300_000) == "2.30M"
assert format_volume_quote(4500) == "4.50K"
def test_okx_turnover_usdt():
qv = _okx_turnover_usdt({"volCcy24h": "100", "last": "50"})
assert qv == 5000.0
def test_cache_needs_refresh_and_merge():
cache = {"rank_date": "2026-06-05", "exchanges": {}}
assert cache_needs_refresh(cache, expected_rank_date="2026-06-07") is True
merged = merge_exchange_rank(
cache,
"binance",
{
"ok": True,
"rank_date": "2026-06-07",
"items": [{"rank": 1, "symbol": "BTC/USDT", "volume_quote": 1.0}],
"total_symbols": 100,
},
)
assert merged["exchanges"]["binance"]["items"][0]["symbol"] == "BTC/USDT"
assert merged["rank_date"] == "2026-06-07"
def test_stale_cache_version_forces_refresh():
cache = {"version": CACHE_VERSION - 1, "rank_date": "2026-06-07", "exchanges": {"okx": {"items": [{}]}}}
assert cache_needs_refresh(cache) is True
def test_short_item_list_is_stale():
items = [{"rank": i, "symbol": f"S{i}/USDT"} for i in range(1, 13)]
row = {"items": items, "total_symbols": 12}
assert _exchange_rank_row_stale(row) is True
full = {"items": items + [{"rank": i, "symbol": f"X{i}/USDT"} for i in range(13, TOP_N_DEFAULT + 1)], "total_symbols": 300}
assert _exchange_rank_row_stale(full) is False
def test_scores_from_binance_uses_fapi_lightweight_api():
ex = MagicMock()
ex.id = "binance"
ex.fapiPublicGetTicker24hr.return_value = [
{"symbol": "BTCUSDT", "quoteVolume": "9000000"},
{"symbol": "ETHUSDT", "quoteVolume": "5000000"},
]
scored = _scores_from_binance(ex)
assert scored[0][1] == "BTC"
assert scored[0][2] == 9000000.0
ex.fetch_tickers.assert_not_called()
def test_scores_from_binance_skips_fetch_tickers_on_api_error():
ex = MagicMock()
ex.id = "binance"
ex.fapiPublicGetTicker24hr.side_effect = RuntimeError("network")
scored = _scores_from_binance(ex)
assert scored == []
ex.fetch_tickers.assert_not_called()
def test_scores_from_gate_uses_futures_tickers_api():
ex = MagicMock()
ex.id = "gateio"
ex.publicFuturesGetSettleTickers.return_value = [
{"contract": "BTC_USDT", "volume_24h_quote": "8000000"},
{"contract": "ETH_USDT", "volume_24h_quote": "4000000"},
]
scored = _scores_from_gate(ex)
assert scored[0][1] == "BTC"
ex.fetch_tickers.assert_not_called()
def test_scores_from_gate_skips_fetch_tickers_on_api_error():
ex = MagicMock()
ex.id = "gateio"
ex.publicFuturesGetSettleTickers.side_effect = RuntimeError("network")
scored = _scores_from_gate(ex)
assert scored == []
ex.fetch_tickers.assert_not_called()
def test_resolve_daily_volume_rank_caches_result():
cache = {"version": 0, "updated_at": 0.0, "ranks": {}, "total": 0}
ex = MagicMock()
ex.id = "binance"
ex.fapiPublicGetTicker24hr.return_value = [
{"symbol": "BTCUSDT", "quoteVolume": "100"},
{"symbol": "ETHUSDT", "quoteVolume": "50"},
]
rank, total = resolve_daily_volume_rank(
"BTC",
cache,
now_ts=1000.0,
ttl_sec=60.0,
exchange=ex,
ensure_markets_loaded=lambda: None,
)
assert rank == 1
assert total == 2
assert cache["version"] == LIQUIDITY_RANK_CACHE_VERSION
calls = ex.fapiPublicGetTicker24hr.call_count
rank2, _ = resolve_daily_volume_rank(
"BTC",
cache,
now_ts=1010.0,
ttl_sec=60.0,
exchange=ex,
ensure_markets_loaded=lambda: None,
)
assert rank2 == 1
assert ex.fapiPublicGetTicker24hr.call_count == calls
def test_resolve_daily_volume_rank_keeps_stale_cache_when_refresh_empty():
cache = {
"version": LIQUIDITY_RANK_CACHE_VERSION,
"updated_at": 900.0,
"ranks": {"BTC": 1},
"total": 100,
}
ex = MagicMock()
ex.id = "binance"
ex.fapiPublicGetTicker24hr.return_value = []
rank, total = resolve_daily_volume_rank(
"BTC",
cache,
now_ts=2000.0,
ttl_sec=60.0,
exchange=ex,
ensure_markets_loaded=lambda: None,
)
assert rank == 1
assert total == 100
assert cache["updated_at"] == 900.0
ex.fetch_tickers.assert_not_called()
def test_build_usdt_swap_volume_ranks():
ex = MagicMock()
ex.id = "binance"
ex.fapiPublicGetTicker24hr.return_value = [
{"symbol": "SOLUSDT", "quoteVolume": "200"},
]
ranks, total = build_usdt_swap_volume_ranks(ex, lambda: None)
assert ranks["SOL"] == 1
assert total == 1
from datetime import datetime
from unittest.mock import MagicMock
from lib.hub.hub_volume_rank_lib import (
CACHE_VERSION,
LIQUIDITY_RANK_CACHE_VERSION,
TOP_N_DEFAULT,
_exchange_rank_row_stale,
_okx_turnover_usdt,
_scores_from_binance,
_scores_from_gate,
build_usdt_swap_volume_ranks,
cache_needs_refresh,
format_volume_quote,
merge_exchange_rank,
rank_date_label,
resolve_daily_volume_rank,
)
def test_rank_date_label_after_reset():
# 2026-06-08 09:00 北京时间 → 昨日交易日 2026-06-07
dt = datetime(2026, 6, 8, 9, 0, 0)
assert rank_date_label(now=dt, reset_hour=8) == "2026-06-07"
def test_rank_date_label_before_reset():
# 2026-06-08 07:00 → 当前交易日仍算 2026-06-07,昨日为 2026-06-06
dt = datetime(2026, 6, 8, 7, 0, 0)
assert rank_date_label(now=dt, reset_hour=8) == "2026-06-06"
def test_format_volume_quote():
assert format_volume_quote(1_500_000_000) == "1.50B"
assert format_volume_quote(2_300_000) == "2.30M"
assert format_volume_quote(4500) == "4.50K"
def test_okx_turnover_usdt():
qv = _okx_turnover_usdt({"volCcy24h": "100", "last": "50"})
assert qv == 5000.0
def test_cache_needs_refresh_and_merge():
cache = {"rank_date": "2026-06-05", "exchanges": {}}
assert cache_needs_refresh(cache, expected_rank_date="2026-06-07") is True
merged = merge_exchange_rank(
cache,
"binance",
{
"ok": True,
"rank_date": "2026-06-07",
"items": [{"rank": 1, "symbol": "BTC/USDT", "volume_quote": 1.0}],
"total_symbols": 100,
},
)
assert merged["exchanges"]["binance"]["items"][0]["symbol"] == "BTC/USDT"
assert merged["rank_date"] == "2026-06-07"
def test_stale_cache_version_forces_refresh():
cache = {"version": CACHE_VERSION - 1, "rank_date": "2026-06-07", "exchanges": {"okx": {"items": [{}]}}}
assert cache_needs_refresh(cache) is True
def test_short_item_list_is_stale():
items = [{"rank": i, "symbol": f"S{i}/USDT"} for i in range(1, 13)]
row = {"items": items, "total_symbols": 12}
assert _exchange_rank_row_stale(row) is True
full = {"items": items + [{"rank": i, "symbol": f"X{i}/USDT"} for i in range(13, TOP_N_DEFAULT + 1)], "total_symbols": 300}
assert _exchange_rank_row_stale(full) is False
def test_scores_from_binance_uses_fapi_lightweight_api():
ex = MagicMock()
ex.id = "binance"
ex.fapiPublicGetTicker24hr.return_value = [
{"symbol": "BTCUSDT", "quoteVolume": "9000000"},
{"symbol": "ETHUSDT", "quoteVolume": "5000000"},
]
scored = _scores_from_binance(ex)
assert scored[0][1] == "BTC"
assert scored[0][2] == 9000000.0
ex.fetch_tickers.assert_not_called()
def test_scores_from_binance_skips_fetch_tickers_on_api_error():
ex = MagicMock()
ex.id = "binance"
ex.fapiPublicGetTicker24hr.side_effect = RuntimeError("network")
scored = _scores_from_binance(ex)
assert scored == []
ex.fetch_tickers.assert_not_called()
def test_scores_from_gate_uses_futures_tickers_api():
ex = MagicMock()
ex.id = "gateio"
ex.publicFuturesGetSettleTickers.return_value = [
{"contract": "BTC_USDT", "volume_24h_quote": "8000000"},
{"contract": "ETH_USDT", "volume_24h_quote": "4000000"},
]
scored = _scores_from_gate(ex)
assert scored[0][1] == "BTC"
ex.fetch_tickers.assert_not_called()
def test_scores_from_gate_skips_fetch_tickers_on_api_error():
ex = MagicMock()
ex.id = "gateio"
ex.publicFuturesGetSettleTickers.side_effect = RuntimeError("network")
scored = _scores_from_gate(ex)
assert scored == []
ex.fetch_tickers.assert_not_called()
def test_resolve_daily_volume_rank_caches_result():
cache = {"version": 0, "updated_at": 0.0, "ranks": {}, "total": 0}
ex = MagicMock()
ex.id = "binance"
ex.fapiPublicGetTicker24hr.return_value = [
{"symbol": "BTCUSDT", "quoteVolume": "100"},
{"symbol": "ETHUSDT", "quoteVolume": "50"},
]
rank, total = resolve_daily_volume_rank(
"BTC",
cache,
now_ts=1000.0,
ttl_sec=60.0,
exchange=ex,
ensure_markets_loaded=lambda: None,
)
assert rank == 1
assert total == 2
assert cache["version"] == LIQUIDITY_RANK_CACHE_VERSION
calls = ex.fapiPublicGetTicker24hr.call_count
rank2, _ = resolve_daily_volume_rank(
"BTC",
cache,
now_ts=1010.0,
ttl_sec=60.0,
exchange=ex,
ensure_markets_loaded=lambda: None,
)
assert rank2 == 1
assert ex.fapiPublicGetTicker24hr.call_count == calls
def test_resolve_daily_volume_rank_keeps_stale_cache_when_refresh_empty():
cache = {
"version": LIQUIDITY_RANK_CACHE_VERSION,
"updated_at": 900.0,
"ranks": {"BTC": 1},
"total": 100,
}
ex = MagicMock()
ex.id = "binance"
ex.fapiPublicGetTicker24hr.return_value = []
rank, total = resolve_daily_volume_rank(
"BTC",
cache,
now_ts=2000.0,
ttl_sec=60.0,
exchange=ex,
ensure_markets_loaded=lambda: None,
)
assert rank == 1
assert total == 100
assert cache["updated_at"] == 900.0
ex.fetch_tickers.assert_not_called()
def test_build_usdt_swap_volume_ranks():
ex = MagicMock()
ex.id = "binance"
ex.fapiPublicGetTicker24hr.return_value = [
{"symbol": "SOLUSDT", "quoteVolume": "200"},
]
ranks, total = build_usdt_swap_volume_ranks(ex, lambda: None)
assert ranks["SOL"] == 1
assert total == 1
+28 -28
View File
@@ -1,28 +1,28 @@
from instance_embed_context_lib import embed_render_plan, trade_records_summary
def test_embed_fragment_trade_is_light():
plan = embed_render_plan("trade", "fragment")
assert plan.exchange_capitals is False
assert plan.records_rows is False
assert plan.records_summary is False
assert plan.orders is True
assert plan.key_history is False
def test_embed_shell_trade_summary_only():
plan = embed_render_plan("trade", "shell")
assert plan.exchange_capitals is True
assert plan.records_summary is True
assert plan.records_rows is False
def test_embed_records_page_loads_rows():
plan = embed_render_plan("records", "fragment")
assert plan.records_rows is True
def test_full_page_unchanged():
plan = embed_render_plan("trade", None)
assert plan.records_rows is True
assert plan.exchange_capitals is True
from lib.instance.instance_embed_context_lib import embed_render_plan, trade_records_summary
def test_embed_fragment_trade_is_light():
plan = embed_render_plan("trade", "fragment")
assert plan.exchange_capitals is False
assert plan.records_rows is False
assert plan.records_summary is False
assert plan.orders is True
assert plan.key_history is False
def test_embed_shell_trade_summary_only():
plan = embed_render_plan("trade", "shell")
assert plan.exchange_capitals is True
assert plan.records_summary is True
assert plan.records_rows is False
def test_embed_records_page_loads_rows():
plan = embed_render_plan("records", "fragment")
assert plan.records_rows is True
def test_full_page_unchanged():
plan = embed_render_plan("trade", None)
assert plan.records_rows is True
assert plan.exchange_capitals is True
+26 -26
View File
@@ -1,26 +1,26 @@
from instance_embed_lib import (
EMBED_TABS,
path_to_embed_tab,
rewrite_embed_dest,
)
def test_path_to_embed_tab():
assert path_to_embed_tab("/trade") == "trade"
assert path_to_embed_tab("/key_monitor") == "key_monitor"
assert path_to_embed_tab("/strategy/records") == "strategy_records"
assert path_to_embed_tab("/unknown") is None
def test_rewrite_embed_dest():
url = rewrite_embed_dest("/trade", hub_theme="dark")
assert url.startswith("/embed?")
assert "tab=trade" in url
assert "embed=1" in url
assert "hub_theme=dark" in url
def test_embed_tabs_cover_main_nav():
assert "trade" in EMBED_TABS
assert "key_monitor" in EMBED_TABS
assert "records" in EMBED_TABS
from lib.instance.instance_embed_lib import (
EMBED_TABS,
path_to_embed_tab,
rewrite_embed_dest,
)
def test_path_to_embed_tab():
assert path_to_embed_tab("/trade") == "trade"
assert path_to_embed_tab("/key_monitor") == "key_monitor"
assert path_to_embed_tab("/strategy/records") == "strategy_records"
assert path_to_embed_tab("/unknown") is None
def test_rewrite_embed_dest():
url = rewrite_embed_dest("/trade", hub_theme="dark")
assert url.startswith("/embed?")
assert "tab=trade" in url
assert "embed=1" in url
assert "hub_theme=dark" in url
def test_embed_tabs_cover_main_nav():
assert "trade" in EMBED_TABS
assert "key_monitor" in EMBED_TABS
assert "records" in EMBED_TABS
+21 -21
View File
@@ -1,21 +1,21 @@
from instance_nav_lib import request_is_hub_soft_nav
def test_request_is_hub_soft_nav():
class Req:
args = {"embed": "1"}
headers = {"X-Instance-Soft-Nav": "1"}
assert request_is_hub_soft_nav(Req()) is True
class Req2:
args = {"embed": "1"}
headers = {}
assert request_is_hub_soft_nav(Req2()) is False
class Req3:
args = {}
headers = {"X-Instance-Soft-Nav": "1"}
assert request_is_hub_soft_nav(Req3()) is False
from lib.instance.instance_nav_lib import request_is_hub_soft_nav
def test_request_is_hub_soft_nav():
class Req:
args = {"embed": "1"}
headers = {"X-Instance-Soft-Nav": "1"}
assert request_is_hub_soft_nav(Req()) is True
class Req2:
args = {"embed": "1"}
headers = {}
assert request_is_hub_soft_nav(Req2()) is False
class Req3:
args = {}
headers = {"X-Instance-Soft-Nav": "1"}
assert request_is_hub_soft_nav(Req3()) is False
+34 -34
View File
@@ -1,34 +1,34 @@
import unittest
from key_monitor_lib import (
BOX_BREAKOUT_CLOSE_OPPOSITE,
box_breakout_invalidate_by_mark,
box_breakout_invalidate_edge_label,
)
class BoxBreakoutInvalidateTests(unittest.TestCase):
def test_short_invalidates_above_upper(self):
self.assertTrue(box_breakout_invalidate_by_mark("short", 62.511, 61.746, 60.569))
def test_short_stays_valid_inside_or_below(self):
self.assertFalse(box_breakout_invalidate_by_mark("short", 61.0, 61.746, 60.569))
self.assertFalse(box_breakout_invalidate_by_mark("short", 60.0, 61.746, 60.569))
def test_long_invalidates_below_lower(self):
self.assertTrue(box_breakout_invalidate_by_mark("long", 94.0, 100.0, 95.0))
def test_long_stays_valid_inside_or_above(self):
self.assertFalse(box_breakout_invalidate_by_mark("long", 98.0, 100.0, 95.0))
self.assertFalse(box_breakout_invalidate_by_mark("long", 101.0, 100.0, 95.0))
def test_edge_label(self):
self.assertEqual(box_breakout_invalidate_edge_label("long"), "下沿")
self.assertEqual(box_breakout_invalidate_edge_label("short"), "上沿")
def test_close_reason_constant(self):
self.assertEqual(BOX_BREAKOUT_CLOSE_OPPOSITE, "box_opposite_break")
if __name__ == "__main__":
unittest.main()
import unittest
from lib.key_monitor.key_monitor_lib import (
BOX_BREAKOUT_CLOSE_OPPOSITE,
box_breakout_invalidate_by_mark,
box_breakout_invalidate_edge_label,
)
class BoxBreakoutInvalidateTests(unittest.TestCase):
def test_short_invalidates_above_upper(self):
self.assertTrue(box_breakout_invalidate_by_mark("short", 62.511, 61.746, 60.569))
def test_short_stays_valid_inside_or_below(self):
self.assertFalse(box_breakout_invalidate_by_mark("short", 61.0, 61.746, 60.569))
self.assertFalse(box_breakout_invalidate_by_mark("short", 60.0, 61.746, 60.569))
def test_long_invalidates_below_lower(self):
self.assertTrue(box_breakout_invalidate_by_mark("long", 94.0, 100.0, 95.0))
def test_long_stays_valid_inside_or_above(self):
self.assertFalse(box_breakout_invalidate_by_mark("long", 98.0, 100.0, 95.0))
self.assertFalse(box_breakout_invalidate_by_mark("long", 101.0, 100.0, 95.0))
def test_edge_label(self):
self.assertEqual(box_breakout_invalidate_edge_label("long"), "下沿")
self.assertEqual(box_breakout_invalidate_edge_label("short"), "上沿")
def test_close_reason_constant(self):
self.assertEqual(BOX_BREAKOUT_CLOSE_OPPOSITE, "box_opposite_break")
if __name__ == "__main__":
unittest.main()
+86 -86
View File
@@ -1,86 +1,86 @@
"""阻力/支撑提醒:占位与间隔防重复推送。"""
from __future__ import annotations
import sqlite3
import unittest
from datetime import datetime, timedelta
from key_monitor_lib import (
claim_rs_level_notify,
notify_interval_elapsed,
run_rs_level_alert_tick,
)
def _row(**kwargs):
base = {
"upper": 2.174,
"lower": 1.694,
"notification_count": 0,
"max_notify": 3,
"notify_interval_min": 5,
"direction": "watch",
"last_notified_at": None,
"last_rs_bar_ts": None,
}
base.update(kwargs)
return base
class TestRsLevelAlertClaim(unittest.TestCase):
def setUp(self):
self.conn = sqlite3.connect(":memory:")
self.conn.execute(
"CREATE TABLE key_monitors ("
"id INTEGER PRIMARY KEY, notification_count INTEGER DEFAULT 0, "
"direction TEXT, last_notified_at TEXT, last_rs_bar_ts INTEGER)"
)
self.conn.execute(
"INSERT INTO key_monitors (id, notification_count, direction) VALUES (1, 0, 'watch')"
)
self.conn.commit()
def test_claim_advances_once_per_index(self):
ok1 = claim_rs_level_notify(
self.conn, 1, 1, "long", "2026-06-02 00:25:00", 1000, prior_count=0
)
self.conn.commit()
self.assertTrue(ok1)
ok_dup = claim_rs_level_notify(
self.conn, 1, 1, "long", "2026-06-02 00:25:03", 1000, prior_count=0
)
self.assertFalse(ok_dup)
ok2 = claim_rs_level_notify(
self.conn, 1, 2, "long", "2026-06-02 00:30:00", 1000, prior_count=1
)
self.conn.commit()
self.assertTrue(ok2)
row = self.conn.execute(
"SELECT notification_count FROM key_monitors WHERE id=1"
).fetchone()
self.assertEqual(row[0], 2)
def test_second_push_requires_interval(self):
now = datetime(2026, 6, 2, 0, 26, 0)
row = _row(
notification_count=1,
direction="long",
last_notified_at="2026-06-02 00:25:00",
)
tick = run_rs_level_alert_tick(row, 2.18, 1000, now, default_max_notify=3, default_interval_min=5)
self.assertIsNone(tick)
later = datetime(2026, 6, 2, 0, 30, 1)
tick2 = run_rs_level_alert_tick(
row, 2.18, 1000, later, default_max_notify=3, default_interval_min=5
)
self.assertIsNotNone(tick2)
self.assertEqual(tick2["notify_index"], 2)
self.assertEqual(tick2["prior_count"], 1)
def test_notify_interval_invalid_timestamp_does_not_spam(self):
now = datetime(2026, 6, 2, 1, 0, 0)
self.assertFalse(notify_interval_elapsed("not-a-date", 5, now))
if __name__ == "__main__":
unittest.main()
"""阻力/支撑提醒:占位与间隔防重复推送。"""
from __future__ import annotations
import sqlite3
import unittest
from datetime import datetime, timedelta
from lib.key_monitor.key_monitor_lib import (
claim_rs_level_notify,
notify_interval_elapsed,
run_rs_level_alert_tick,
)
def _row(**kwargs):
base = {
"upper": 2.174,
"lower": 1.694,
"notification_count": 0,
"max_notify": 3,
"notify_interval_min": 5,
"direction": "watch",
"last_notified_at": None,
"last_rs_bar_ts": None,
}
base.update(kwargs)
return base
class TestRsLevelAlertClaim(unittest.TestCase):
def setUp(self):
self.conn = sqlite3.connect(":memory:")
self.conn.execute(
"CREATE TABLE key_monitors ("
"id INTEGER PRIMARY KEY, notification_count INTEGER DEFAULT 0, "
"direction TEXT, last_notified_at TEXT, last_rs_bar_ts INTEGER)"
)
self.conn.execute(
"INSERT INTO key_monitors (id, notification_count, direction) VALUES (1, 0, 'watch')"
)
self.conn.commit()
def test_claim_advances_once_per_index(self):
ok1 = claim_rs_level_notify(
self.conn, 1, 1, "long", "2026-06-02 00:25:00", 1000, prior_count=0
)
self.conn.commit()
self.assertTrue(ok1)
ok_dup = claim_rs_level_notify(
self.conn, 1, 1, "long", "2026-06-02 00:25:03", 1000, prior_count=0
)
self.assertFalse(ok_dup)
ok2 = claim_rs_level_notify(
self.conn, 1, 2, "long", "2026-06-02 00:30:00", 1000, prior_count=1
)
self.conn.commit()
self.assertTrue(ok2)
row = self.conn.execute(
"SELECT notification_count FROM key_monitors WHERE id=1"
).fetchone()
self.assertEqual(row[0], 2)
def test_second_push_requires_interval(self):
now = datetime(2026, 6, 2, 0, 26, 0)
row = _row(
notification_count=1,
direction="long",
last_notified_at="2026-06-02 00:25:00",
)
tick = run_rs_level_alert_tick(row, 2.18, 1000, now, default_max_notify=3, default_interval_min=5)
self.assertIsNone(tick)
later = datetime(2026, 6, 2, 0, 30, 1)
tick2 = run_rs_level_alert_tick(
row, 2.18, 1000, later, default_max_notify=3, default_interval_min=5
)
self.assertIsNotNone(tick2)
self.assertEqual(tick2["notify_index"], 2)
self.assertEqual(tick2["prior_count"], 1)
def test_notify_interval_invalid_timestamp_does_not_spam(self):
now = datetime(2026, 6, 2, 1, 0, 0)
self.assertFalse(notify_interval_elapsed("not-a-date", 5, now))
if __name__ == "__main__":
unittest.main()
+27 -27
View File
@@ -1,27 +1,27 @@
import unittest
from key_monitor_lib import (
KEY_MONITOR_RS_TYPE,
is_rs_key_monitor_type,
rs_monitor_type_for_storage,
rs_monitor_type_label,
)
class KeyMonitorRsTypeTests(unittest.TestCase):
def test_legacy_types_still_recognized(self):
self.assertTrue(is_rs_key_monitor_type("关键阻力位"))
self.assertTrue(is_rs_key_monitor_type("关键支撑位"))
def test_storage_normalizes_to_unified_type(self):
self.assertEqual(rs_monitor_type_for_storage("关键阻力位"), KEY_MONITOR_RS_TYPE)
self.assertEqual(rs_monitor_type_for_storage("关键支撑位"), KEY_MONITOR_RS_TYPE)
self.assertEqual(rs_monitor_type_for_storage(KEY_MONITOR_RS_TYPE), KEY_MONITOR_RS_TYPE)
def test_label_merges_legacy_display(self):
self.assertEqual(rs_monitor_type_label("关键阻力位"), KEY_MONITOR_RS_TYPE)
self.assertEqual(rs_monitor_type_label("箱体突破"), "箱体突破")
if __name__ == "__main__":
unittest.main()
import unittest
from lib.key_monitor.key_monitor_lib import (
KEY_MONITOR_RS_TYPE,
is_rs_key_monitor_type,
rs_monitor_type_for_storage,
rs_monitor_type_label,
)
class KeyMonitorRsTypeTests(unittest.TestCase):
def test_legacy_types_still_recognized(self):
self.assertTrue(is_rs_key_monitor_type("关键阻力位"))
self.assertTrue(is_rs_key_monitor_type("关键支撑位"))
def test_storage_normalizes_to_unified_type(self):
self.assertEqual(rs_monitor_type_for_storage("关键阻力位"), KEY_MONITOR_RS_TYPE)
self.assertEqual(rs_monitor_type_for_storage("关键支撑位"), KEY_MONITOR_RS_TYPE)
self.assertEqual(rs_monitor_type_for_storage(KEY_MONITOR_RS_TYPE), KEY_MONITOR_RS_TYPE)
def test_label_merges_legacy_display(self):
self.assertEqual(rs_monitor_type_label("关键阻力位"), KEY_MONITOR_RS_TYPE)
self.assertEqual(rs_monitor_type_label("箱体突破"), "箱体突破")
if __name__ == "__main__":
unittest.main()
+32 -32
View File
@@ -1,32 +1,32 @@
from manual_sltp_lib import (
MANUAL_FIXED_RR_DEFAULT,
calc_tp_from_fixed_rr,
parse_fixed_rr,
resolve_open_sltp_prices,
)
def test_calc_tp_from_fixed_rr_long():
tp = calc_tp_from_fixed_rr("long", 100.0, 95.0, 1.5)
assert tp == 107.5
def test_calc_tp_from_fixed_rr_short():
tp = calc_tp_from_fixed_rr("short", 100.0, 105.0, 1.5)
assert tp == 92.5
def test_resolve_open_fixed_rr_mode():
sl, tp = resolve_open_sltp_prices(
"long",
100.0,
"fixed_rr",
{"sl": "95", "fixed_rr": "1.5"},
)
assert sl == 95.0
assert tp == 107.5
def test_parse_fixed_rr_default():
assert parse_fixed_rr(None) == MANUAL_FIXED_RR_DEFAULT
assert parse_fixed_rr("2") == 2.0
from lib.trade.manual_sltp_lib import (
MANUAL_FIXED_RR_DEFAULT,
calc_tp_from_fixed_rr,
parse_fixed_rr,
resolve_open_sltp_prices,
)
def test_calc_tp_from_fixed_rr_long():
tp = calc_tp_from_fixed_rr("long", 100.0, 95.0, 1.5)
assert tp == 107.5
def test_calc_tp_from_fixed_rr_short():
tp = calc_tp_from_fixed_rr("short", 100.0, 105.0, 1.5)
assert tp == 92.5
def test_resolve_open_fixed_rr_mode():
sl, tp = resolve_open_sltp_prices(
"long",
100.0,
"fixed_rr",
{"sl": "95", "fixed_rr": "1.5"},
)
assert sl == 95.0
assert tp == 107.5
def test_parse_fixed_rr_default():
assert parse_fixed_rr(None) == MANUAL_FIXED_RR_DEFAULT
assert parse_fixed_rr("2") == 2.0
+102 -102
View File
@@ -1,102 +1,102 @@
from order_monitor_display_lib import (
apply_order_price_display_fields,
is_sl_breakeven_secured,
monitor_open_stop_loss,
order_monitor_tpsl_needs_sync,
resolve_live_tpsl_prices,
sl_breakeven_from_exchange_tpsl,
snapshot_rr,
snapshot_stop_loss,
)
def _calc_rr(direction, entry, sl, tp):
if direction == "long":
risk = entry - sl
reward = tp - entry
else:
risk = sl - entry
reward = entry - tp
if risk <= 0 or reward <= 0:
return None
return round(reward / risk, 4)
def test_snapshot_stop_loss_prefers_initial():
assert snapshot_stop_loss(2.45, 2.6) == 2.45
assert snapshot_stop_loss(None, 2.6) == 2.6
def test_monitor_open_stop_loss_prefers_initial_snapshot():
row = {"initial_stop_loss": 64000, "stop_loss": 63200}
assert monitor_open_stop_loss(row) == 64000
def test_snapshot_rr_ignores_current_stop_after_manual_move():
rr = snapshot_rr(_calc_rr, "long", 2.726, 2.45, 2.65, 3.3)
assert rr is not None
assert rr > 2.0
def test_breakeven_long():
assert is_sl_breakeven_secured("long", 2.726, 2.726) is True
assert is_sl_breakeven_secured("long", 2.726, 2.75) is True
assert is_sl_breakeven_secured("long", 2.726, 2.45) is False
def test_breakeven_short():
assert is_sl_breakeven_secured("short", 72.73, 72.73) is True
assert is_sl_breakeven_secured("short", 72.73, 72.0) is True
assert is_sl_breakeven_secured("short", 72.73, 74.0) is False
def test_sl_breakeven_from_exchange_tpsl():
ok = sl_breakeven_from_exchange_tpsl(
"long",
2.726,
{"sl": {"trigger_price": 2.735}, "tp": {"trigger_price": 3.3}},
)
assert ok is True
def test_resolve_live_tpsl_prefers_exchange():
disp_sl, disp_tp, ex_sl, ex_tp = resolve_live_tpsl_prices(
1674,
1647.65,
{"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}},
)
assert disp_sl == 1661
assert disp_tp == 1647.65
assert ex_sl == 1661
assert ex_tp == 1647.65
def test_order_monitor_tpsl_needs_sync_detects_sl_change():
new_sl, new_tp, changed = order_monitor_tpsl_needs_sync(
1674,
1647.65,
{"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}},
)
assert changed is True
assert new_sl == 1661
assert new_tp == 1647.65
def test_apply_order_price_display_fields_live_sl():
payload = {}
apply_order_price_display_fields(
payload,
direction="short",
entry_price=1663.45,
initial_stop_loss=1674,
stop_loss=1674,
take_profit=1647.65,
calc_rr_ratio_fn=_calc_rr,
exchange_tpsl={"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}},
format_price_fn=lambda _s, v: f"{v:.2f}",
symbol="ETH/USDT:USDT",
)
assert payload["stop_loss"] == 1661
assert payload["stop_loss_display"] == "1661.00"
assert payload["sl_breakeven_secured"] is True
assert payload["rr_ratio"] is not None
from lib.trade.order_monitor_display_lib import (
apply_order_price_display_fields,
is_sl_breakeven_secured,
monitor_open_stop_loss,
order_monitor_tpsl_needs_sync,
resolve_live_tpsl_prices,
sl_breakeven_from_exchange_tpsl,
snapshot_rr,
snapshot_stop_loss,
)
def _calc_rr(direction, entry, sl, tp):
if direction == "long":
risk = entry - sl
reward = tp - entry
else:
risk = sl - entry
reward = entry - tp
if risk <= 0 or reward <= 0:
return None
return round(reward / risk, 4)
def test_snapshot_stop_loss_prefers_initial():
assert snapshot_stop_loss(2.45, 2.6) == 2.45
assert snapshot_stop_loss(None, 2.6) == 2.6
def test_monitor_open_stop_loss_prefers_initial_snapshot():
row = {"initial_stop_loss": 64000, "stop_loss": 63200}
assert monitor_open_stop_loss(row) == 64000
def test_snapshot_rr_ignores_current_stop_after_manual_move():
rr = snapshot_rr(_calc_rr, "long", 2.726, 2.45, 2.65, 3.3)
assert rr is not None
assert rr > 2.0
def test_breakeven_long():
assert is_sl_breakeven_secured("long", 2.726, 2.726) is True
assert is_sl_breakeven_secured("long", 2.726, 2.75) is True
assert is_sl_breakeven_secured("long", 2.726, 2.45) is False
def test_breakeven_short():
assert is_sl_breakeven_secured("short", 72.73, 72.73) is True
assert is_sl_breakeven_secured("short", 72.73, 72.0) is True
assert is_sl_breakeven_secured("short", 72.73, 74.0) is False
def test_sl_breakeven_from_exchange_tpsl():
ok = sl_breakeven_from_exchange_tpsl(
"long",
2.726,
{"sl": {"trigger_price": 2.735}, "tp": {"trigger_price": 3.3}},
)
assert ok is True
def test_resolve_live_tpsl_prefers_exchange():
disp_sl, disp_tp, ex_sl, ex_tp = resolve_live_tpsl_prices(
1674,
1647.65,
{"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}},
)
assert disp_sl == 1661
assert disp_tp == 1647.65
assert ex_sl == 1661
assert ex_tp == 1647.65
def test_order_monitor_tpsl_needs_sync_detects_sl_change():
new_sl, new_tp, changed = order_monitor_tpsl_needs_sync(
1674,
1647.65,
{"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}},
)
assert changed is True
assert new_sl == 1661
assert new_tp == 1647.65
def test_apply_order_price_display_fields_live_sl():
payload = {}
apply_order_price_display_fields(
payload,
direction="short",
entry_price=1663.45,
initial_stop_loss=1674,
stop_loss=1674,
take_profit=1647.65,
calc_rr_ratio_fn=_calc_rr,
exchange_tpsl={"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}},
format_price_fn=lambda _s, v: f"{v:.2f}",
symbol="ETH/USDT:USDT",
)
assert payload["stop_loss"] == 1661
assert payload["stop_loss_display"] == "1661.00"
assert payload["sl_breakeven_secured"] is True
assert payload["rr_ratio"] is not None
+78 -78
View File
@@ -1,78 +1,78 @@
import sqlite3
import unittest
from strategy_db import init_strategy_tables
from strategy_trade_labels import (
MONITOR_TYPE_TREND_PULLBACK,
count_position_limit_active_monitors,
)
def _mem_conn():
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"""CREATE TABLE order_monitors (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT,
direction TEXT,
status TEXT,
monitor_type TEXT,
key_signal_type TEXT,
trend_plan_id INTEGER
)"""
)
init_strategy_tables(conn)
return conn
class PositionLimitCountTests(unittest.TestCase):
def test_regular_monitor_counts(self):
conn = _mem_conn()
conn.execute(
"INSERT INTO order_monitors (symbol, status, monitor_type) VALUES ('ETH/USDT', 'active', '下单监控')"
)
conn.commit()
self.assertEqual(count_position_limit_active_monitors(conn), 1)
def test_trend_pullback_excluded(self):
conn = _mem_conn()
conn.execute(
"""INSERT INTO order_monitors
(symbol, status, monitor_type, trend_plan_id)
VALUES ('ETH/USDT', 'active', ?, 12)""",
(MONITOR_TYPE_TREND_PULLBACK,),
)
conn.commit()
self.assertEqual(count_position_limit_active_monitors(conn), 0)
def test_active_roll_group_still_counts_regular_monitor(self):
conn = _mem_conn()
conn.execute(
"INSERT INTO order_monitors (id, symbol, status, monitor_type) VALUES (1, 'ETH/USDT', 'active', '下单监控')"
)
conn.execute(
"""INSERT INTO roll_groups
(order_monitor_id, symbol, direction, status)
VALUES (1, 'ETH/USDT', 'long', 'active')"""
)
conn.commit()
self.assertEqual(count_position_limit_active_monitors(conn), 1)
def test_mixed_monitors(self):
conn = _mem_conn()
conn.execute(
"INSERT INTO order_monitors (symbol, status, monitor_type) VALUES ('BTC/USDT', 'active', '下单监控')"
)
conn.execute(
"""INSERT INTO order_monitors
(symbol, status, monitor_type, trend_plan_id)
VALUES ('ETH/USDT', 'active', ?, 3)""",
(MONITOR_TYPE_TREND_PULLBACK,),
)
conn.commit()
self.assertEqual(count_position_limit_active_monitors(conn), 1)
if __name__ == "__main__":
unittest.main()
import sqlite3
import unittest
from lib.strategy.strategy_db import init_strategy_tables
from lib.strategy.strategy_trade_labels import (
MONITOR_TYPE_TREND_PULLBACK,
count_position_limit_active_monitors,
)
def _mem_conn():
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"""CREATE TABLE order_monitors (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT,
direction TEXT,
status TEXT,
monitor_type TEXT,
key_signal_type TEXT,
trend_plan_id INTEGER
)"""
)
init_strategy_tables(conn)
return conn
class PositionLimitCountTests(unittest.TestCase):
def test_regular_monitor_counts(self):
conn = _mem_conn()
conn.execute(
"INSERT INTO order_monitors (symbol, status, monitor_type) VALUES ('ETH/USDT', 'active', '下单监控')"
)
conn.commit()
self.assertEqual(count_position_limit_active_monitors(conn), 1)
def test_trend_pullback_excluded(self):
conn = _mem_conn()
conn.execute(
"""INSERT INTO order_monitors
(symbol, status, monitor_type, trend_plan_id)
VALUES ('ETH/USDT', 'active', ?, 12)""",
(MONITOR_TYPE_TREND_PULLBACK,),
)
conn.commit()
self.assertEqual(count_position_limit_active_monitors(conn), 0)
def test_active_roll_group_still_counts_regular_monitor(self):
conn = _mem_conn()
conn.execute(
"INSERT INTO order_monitors (id, symbol, status, monitor_type) VALUES (1, 'ETH/USDT', 'active', '下单监控')"
)
conn.execute(
"""INSERT INTO roll_groups
(order_monitor_id, symbol, direction, status)
VALUES (1, 'ETH/USDT', 'long', 'active')"""
)
conn.commit()
self.assertEqual(count_position_limit_active_monitors(conn), 1)
def test_mixed_monitors(self):
conn = _mem_conn()
conn.execute(
"INSERT INTO order_monitors (symbol, status, monitor_type) VALUES ('BTC/USDT', 'active', '下单监控')"
)
conn.execute(
"""INSERT INTO order_monitors
(symbol, status, monitor_type, trend_plan_id)
VALUES ('ETH/USDT', 'active', ?, 3)""",
(MONITOR_TYPE_TREND_PULLBACK,),
)
conn.commit()
self.assertEqual(count_position_limit_active_monitors(conn), 1)
if __name__ == "__main__":
unittest.main()
+34 -34
View File
@@ -1,34 +1,34 @@
"""全仓 / 以损定仓 风险展示文案。"""
from __future__ import annotations
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from position_sizing_lib import ( # noqa: E402
format_risk_display_text,
risk_percent_for_storage,
)
class TestPositionSizingRiskDisplay(unittest.TestCase):
def test_full_margin_shows_amount_only(self):
self.assertEqual(
format_risk_display_text("full_margin", 1.0, 2.58, decimals=2),
"2.58U",
)
self.assertIsNone(risk_percent_for_storage("full_margin", 1.0))
def test_risk_mode_shows_percent_and_amount(self):
self.assertEqual(
format_risk_display_text("risk", 2.0, 10.5, decimals=2),
"2%≈10.5U",
)
self.assertEqual(risk_percent_for_storage("risk", 2.0), 2.0)
if __name__ == "__main__":
unittest.main()
"""全仓 / 以损定仓 风险展示文案。"""
from __future__ import annotations
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from lib.trade.position_sizing_lib import ( # noqa: E402
format_risk_display_text,
risk_percent_for_storage,
)
class TestPositionSizingRiskDisplay(unittest.TestCase):
def test_full_margin_shows_amount_only(self):
self.assertEqual(
format_risk_display_text("full_margin", 1.0, 2.58, decimals=2),
"2.58U",
)
self.assertIsNone(risk_percent_for_storage("full_margin", 1.0))
def test_risk_mode_shows_percent_and_amount(self):
self.assertEqual(
format_risk_display_text("risk", 2.0, 10.5, decimals=2),
"2%≈10.5U",
)
self.assertEqual(risk_percent_for_storage("risk", 2.0), 2.0)
if __name__ == "__main__":
unittest.main()
+112 -112
View File
@@ -1,112 +1,112 @@
from strategy_roll_lib import (
preview_roll,
roll_breakout_invalidate,
roll_breakout_trigger_crossed,
roll_fib_invalidate,
roll_fib_trigger_crossed,
solve_add_amount_for_total_risk,
validate_roll_geometry,
)
def test_solve_add_amount_long_one_risk():
q2, err = solve_add_amount_for_total_risk(
"long", 1.0, 3000.0, 3100.0, 2950.0, 200.0, 1.0
)
assert err is None
avg = (1 * 3000 + q2 * 3100) / (1 + q2)
loss = (avg - 2950) * (1 + q2)
assert abs(loss - 200.0) < 0.01
def test_preview_roll_market_short():
preview, err = preview_roll(
direction="short",
symbol="HYPE/USDT",
qty_existing=3.0,
entry_existing=65.0,
initial_take_profit=60.0,
add_mode="market",
new_stop_loss=66.5,
risk_percent=2.0,
capital_base_usdt=1000.0,
add_price=64.0,
legs_done=1,
)
assert err is None
assert preview["add_mode_label"] == "市价加仓"
sl = preview["new_stop_loss"]
avg = preview["avg_entry_after"]
qty = preview["qty_after"]
loss = (sl - avg) * qty
assert abs(loss - 20.0) < 0.01
def test_fib_cross_long_down():
assert roll_fib_trigger_crossed("long", 101.0, 100.0, 100.5) is True
assert roll_fib_trigger_crossed("long", 100.6, 100.6, 100.5) is False
def test_breakout_cross_long_up():
assert roll_breakout_trigger_crossed("long", 99.0, 100.5, 100.0) is True
assert roll_breakout_invalidate("long", 98.0, 99.0) is True
assert roll_fib_invalidate("long", 110.0, 105.0, 95.0) is True
def test_preview_breakout_mode_label():
preview, err = preview_roll(
direction="long",
symbol="ETH/USDT",
qty_existing=1.0,
entry_existing=3000.0,
initial_take_profit=3500.0,
add_mode="breakout",
new_stop_loss=2980.0,
breakthrough_price=3100.0,
risk_percent=10.0,
capital_base_usdt=1000.0,
add_price=3050.0,
)
assert err is None
assert preview["add_mode_label"] == "突破加仓"
def test_breakout_geometry_short_mark_above_breakout():
err = validate_roll_geometry(
"short",
"breakout",
new_stop_loss=568.0,
breakthrough_price=551.0,
entry_existing=560.0,
initial_take_profit=540.0,
mark_price=560.0,
)
assert err is None
def test_breakout_geometry_short_rejects_mark_at_or_below_breakout():
err = validate_roll_geometry(
"short",
"breakout",
new_stop_loss=568.0,
breakthrough_price=551.0,
entry_existing=560.0,
initial_take_profit=540.0,
mark_price=551.0,
)
assert err is not None
assert "高于突破价" in err
def test_breakout_geometry_long_rejects_mark_at_or_above_breakout():
err = validate_roll_geometry(
"long",
"breakout",
new_stop_loss=2980.0,
breakthrough_price=3100.0,
entry_existing=3000.0,
initial_take_profit=3500.0,
mark_price=3100.0,
)
assert err is not None
assert "低于突破价" in err
from lib.strategy.strategy_roll_lib import (
preview_roll,
roll_breakout_invalidate,
roll_breakout_trigger_crossed,
roll_fib_invalidate,
roll_fib_trigger_crossed,
solve_add_amount_for_total_risk,
validate_roll_geometry,
)
def test_solve_add_amount_long_one_risk():
q2, err = solve_add_amount_for_total_risk(
"long", 1.0, 3000.0, 3100.0, 2950.0, 200.0, 1.0
)
assert err is None
avg = (1 * 3000 + q2 * 3100) / (1 + q2)
loss = (avg - 2950) * (1 + q2)
assert abs(loss - 200.0) < 0.01
def test_preview_roll_market_short():
preview, err = preview_roll(
direction="short",
symbol="HYPE/USDT",
qty_existing=3.0,
entry_existing=65.0,
initial_take_profit=60.0,
add_mode="market",
new_stop_loss=66.5,
risk_percent=2.0,
capital_base_usdt=1000.0,
add_price=64.0,
legs_done=1,
)
assert err is None
assert preview["add_mode_label"] == "市价加仓"
sl = preview["new_stop_loss"]
avg = preview["avg_entry_after"]
qty = preview["qty_after"]
loss = (sl - avg) * qty
assert abs(loss - 20.0) < 0.01
def test_fib_cross_long_down():
assert roll_fib_trigger_crossed("long", 101.0, 100.0, 100.5) is True
assert roll_fib_trigger_crossed("long", 100.6, 100.6, 100.5) is False
def test_breakout_cross_long_up():
assert roll_breakout_trigger_crossed("long", 99.0, 100.5, 100.0) is True
assert roll_breakout_invalidate("long", 98.0, 99.0) is True
assert roll_fib_invalidate("long", 110.0, 105.0, 95.0) is True
def test_preview_breakout_mode_label():
preview, err = preview_roll(
direction="long",
symbol="ETH/USDT",
qty_existing=1.0,
entry_existing=3000.0,
initial_take_profit=3500.0,
add_mode="breakout",
new_stop_loss=2980.0,
breakthrough_price=3100.0,
risk_percent=10.0,
capital_base_usdt=1000.0,
add_price=3050.0,
)
assert err is None
assert preview["add_mode_label"] == "突破加仓"
def test_breakout_geometry_short_mark_above_breakout():
err = validate_roll_geometry(
"short",
"breakout",
new_stop_loss=568.0,
breakthrough_price=551.0,
entry_existing=560.0,
initial_take_profit=540.0,
mark_price=560.0,
)
assert err is None
def test_breakout_geometry_short_rejects_mark_at_or_below_breakout():
err = validate_roll_geometry(
"short",
"breakout",
new_stop_loss=568.0,
breakthrough_price=551.0,
entry_existing=560.0,
initial_take_profit=540.0,
mark_price=551.0,
)
assert err is not None
assert "高于突破价" in err
def test_breakout_geometry_long_rejects_mark_at_or_above_breakout():
err = validate_roll_geometry(
"long",
"breakout",
new_stop_loss=2980.0,
breakthrough_price=3100.0,
entry_existing=3000.0,
initial_take_profit=3500.0,
mark_price=3100.0,
)
assert err is not None
assert "低于突破价" in err
+44 -44
View File
@@ -1,44 +1,44 @@
"""strategy_roll_ui_lib 单元测试。"""
from __future__ import annotations
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
import strategy_roll_ui_lib as roll_ui
def test_compute_roll_chain_metrics_short():
group = {
"id": 1,
"direction": "short",
"initial_take_profit": 60.0,
}
legs = [
{"id": 10, "leg_index": 1, "amount": 3.0, "fill_price": 65.0, "status": "filled"},
{"id": 11, "leg_index": 2, "amount": 5.0, "fill_price": 64.0, "status": "filled"},
]
per_leg, group_metrics = roll_ui.compute_roll_chain_metrics(
group,
legs,
qty_live=8.0,
entry_live=63.5,
monitor={"trigger_price": 66.0, "order_amount": 3.0},
)
assert per_leg[10]["avg_entry_after"] is not None
assert per_leg[11]["avg_entry_after"] is not None
assert group_metrics["reward_at_tp_usdt"] is not None
assert per_leg[11]["reward_at_tp_usdt"] >= per_leg[10]["reward_at_tp_usdt"]
def test_infer_initial_position_from_live():
legs = [{"amount": 2.0, "fill_price": 64.0, "status": "filled"}]
q0, e0 = roll_ui.infer_initial_position(5.0, 63.0, legs)
assert q0 == 3.0
assert abs(e0 - 62.3333333333) < 0.001
def test_reward_at_tp_long():
assert roll_ui.reward_at_tp_usdt("long", 100.0, 110.0, 2.0) == 20.0
"""strategy_roll_ui_lib 单元测试。"""
from __future__ import annotations
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
import lib.strategy.strategy_roll_ui_lib as roll_ui
def test_compute_roll_chain_metrics_short():
group = {
"id": 1,
"direction": "short",
"initial_take_profit": 60.0,
}
legs = [
{"id": 10, "leg_index": 1, "amount": 3.0, "fill_price": 65.0, "status": "filled"},
{"id": 11, "leg_index": 2, "amount": 5.0, "fill_price": 64.0, "status": "filled"},
]
per_leg, group_metrics = roll_ui.compute_roll_chain_metrics(
group,
legs,
qty_live=8.0,
entry_live=63.5,
monitor={"trigger_price": 66.0, "order_amount": 3.0},
)
assert per_leg[10]["avg_entry_after"] is not None
assert per_leg[11]["avg_entry_after"] is not None
assert group_metrics["reward_at_tp_usdt"] is not None
assert per_leg[11]["reward_at_tp_usdt"] >= per_leg[10]["reward_at_tp_usdt"]
def test_infer_initial_position_from_live():
legs = [{"amount": 2.0, "fill_price": 64.0, "status": "filled"}]
q0, e0 = roll_ui.infer_initial_position(5.0, 63.0, legs)
assert q0 == 3.0
assert abs(e0 - 62.3333333333) < 0.001
def test_reward_at_tp_long():
assert roll_ui.reward_at_tp_usdt("long", 100.0, 110.0, 2.0) == 20.0
+183 -183
View File
@@ -1,183 +1,183 @@
"""策略快照:同一计划同结果不重复写入。"""
from __future__ import annotations
import json
import sqlite3
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from strategy_snapshot_lib import ( # noqa: E402
STRATEGY_TREND,
dedupe_strategy_snapshots,
init_strategy_snapshot_table,
list_strategy_snapshots,
save_trend_plan_snapshot,
)
def _mem_conn() -> sqlite3.Connection:
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
init_strategy_snapshot_table(conn)
return conn
def test_save_trend_plan_snapshot_skips_duplicate_result():
conn = _mem_conn()
plan = {
"id": 42,
"symbol": "ONDO/USDT",
"exchange_symbol": "ONDO/USDT:USDT",
"direction": "short",
"status": "active",
"opened_at": "2026-06-08 08:00:00",
"legs_done": 4,
"dca_legs": 4,
"first_order_done": 1,
"grid_prices_json": "[]",
"leg_amounts_json": "[]",
}
cfg = {"app_module": type("M", (), {"app_now_str": staticmethod(lambda: "2026-06-08 08:41:00")})()}
save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.3)
save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.4)
conn.commit()
rows = conn.execute(
"SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=? AND result_label=?",
(42, "止损"),
).fetchone()
assert int(rows["c"]) == 1
def test_dedupe_strategy_snapshots_handles_many_duplicates():
conn = _mem_conn()
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
for snap_id in range(1, 46):
conn.execute(
"""INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
) VALUES (?,?,?,?,?,?,?,?,?)""",
(
snap_id,
STRATEGY_TREND,
99,
"ONDO/USDT",
"止损",
payload,
"2026-06-08 08:41:00",
"2026-06-08 08:41:00",
-2.2,
),
)
conn.commit()
removed = dedupe_strategy_snapshots(conn)
conn.commit()
assert removed == 44
row = conn.execute(
"SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=?",
(99,),
).fetchone()
assert int(row["c"]) == 1
def test_dedupe_strategy_snapshots_keeps_latest_id():
conn = _mem_conn()
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
for snap_id, pnl in ((1, -2.23), (2, -2.31), (3, -2.38)):
conn.execute(
"""INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
) VALUES (?,?,?,?,?,?,?,?,?)""",
(
snap_id,
STRATEGY_TREND,
5,
"ONDO/USDT",
"止损",
payload,
"2026-06-08 08:41:00",
"2026-06-08 08:41:00",
pnl,
),
)
conn.commit()
removed = dedupe_strategy_snapshots(conn)
conn.commit()
assert removed == 2
row = conn.execute(
"SELECT id, pnl_amount FROM strategy_trade_snapshots WHERE source_id=?",
(5,),
).fetchone()
assert int(row["id"]) == 3
assert abs(float(row["pnl_amount"]) - (-2.38)) < 1e-6
def test_list_strategy_snapshots_hides_duplicate_keys():
conn = _mem_conn()
payload = json.dumps({"symbol": "ONDO/USDT", "dca_levels": []}, ensure_ascii=False)
for snap_id in (10, 11, 12):
conn.execute(
"""INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, direction, result_label,
snapshot_json, closed_at, created_at, pnl_amount
) VALUES (?,?,?,?,?,?,?,?,?,?)""",
(
snap_id,
STRATEGY_TREND,
7,
"ONDO/USDT",
"short",
"止损",
payload,
"2026-06-08 08:41:00",
"2026-06-08 08:41:00",
-2.2,
),
)
conn.commit()
rows = list_strategy_snapshots(conn, limit=50)
stop_rows = [r for r in rows if int(r.get("source_id") or 0) == 7]
assert len(stop_rows) == 1
assert int(stop_rows[0]["id"]) == 12
def test_dedupe_keeps_manual_over_stop_loss():
conn = _mem_conn()
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
for snap_id, label in ((10, "止损"), (11, "手动平仓")):
conn.execute(
"""INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
) VALUES (?,?,?,?,?,?,?,?,?)""",
(
snap_id,
STRATEGY_TREND,
7,
"ONDO/USDT",
label,
payload,
"2026-06-08 08:44:00",
"2026-06-08 08:44:00",
-2.23,
),
)
conn.commit()
removed = dedupe_strategy_snapshots(conn)
conn.commit()
assert removed == 1
row = conn.execute(
"SELECT result_label FROM strategy_trade_snapshots WHERE source_id=?",
(7,),
).fetchone()
assert row["result_label"] == "手动平仓"
if __name__ == "__main__":
test_save_trend_plan_snapshot_skips_duplicate_result()
test_dedupe_strategy_snapshots_handles_many_duplicates()
test_dedupe_strategy_snapshots_keeps_latest_id()
test_list_strategy_snapshots_hides_duplicate_keys()
test_dedupe_keeps_manual_over_stop_loss()
print("all ok")
"""策略快照:同一计划同结果不重复写入。"""
from __future__ import annotations
import json
import sqlite3
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from lib.strategy.strategy_snapshot_lib import ( # noqa: E402
STRATEGY_TREND,
dedupe_strategy_snapshots,
init_strategy_snapshot_table,
list_strategy_snapshots,
save_trend_plan_snapshot,
)
def _mem_conn() -> sqlite3.Connection:
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
init_strategy_snapshot_table(conn)
return conn
def test_save_trend_plan_snapshot_skips_duplicate_result():
conn = _mem_conn()
plan = {
"id": 42,
"symbol": "ONDO/USDT",
"exchange_symbol": "ONDO/USDT:USDT",
"direction": "short",
"status": "active",
"opened_at": "2026-06-08 08:00:00",
"legs_done": 4,
"dca_legs": 4,
"first_order_done": 1,
"grid_prices_json": "[]",
"leg_amounts_json": "[]",
}
cfg = {"app_module": type("M", (), {"app_now_str": staticmethod(lambda: "2026-06-08 08:41:00")})()}
save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.3)
save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.4)
conn.commit()
rows = conn.execute(
"SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=? AND result_label=?",
(42, "止损"),
).fetchone()
assert int(rows["c"]) == 1
def test_dedupe_strategy_snapshots_handles_many_duplicates():
conn = _mem_conn()
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
for snap_id in range(1, 46):
conn.execute(
"""INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
) VALUES (?,?,?,?,?,?,?,?,?)""",
(
snap_id,
STRATEGY_TREND,
99,
"ONDO/USDT",
"止损",
payload,
"2026-06-08 08:41:00",
"2026-06-08 08:41:00",
-2.2,
),
)
conn.commit()
removed = dedupe_strategy_snapshots(conn)
conn.commit()
assert removed == 44
row = conn.execute(
"SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=?",
(99,),
).fetchone()
assert int(row["c"]) == 1
def test_dedupe_strategy_snapshots_keeps_latest_id():
conn = _mem_conn()
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
for snap_id, pnl in ((1, -2.23), (2, -2.31), (3, -2.38)):
conn.execute(
"""INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
) VALUES (?,?,?,?,?,?,?,?,?)""",
(
snap_id,
STRATEGY_TREND,
5,
"ONDO/USDT",
"止损",
payload,
"2026-06-08 08:41:00",
"2026-06-08 08:41:00",
pnl,
),
)
conn.commit()
removed = dedupe_strategy_snapshots(conn)
conn.commit()
assert removed == 2
row = conn.execute(
"SELECT id, pnl_amount FROM strategy_trade_snapshots WHERE source_id=?",
(5,),
).fetchone()
assert int(row["id"]) == 3
assert abs(float(row["pnl_amount"]) - (-2.38)) < 1e-6
def test_list_strategy_snapshots_hides_duplicate_keys():
conn = _mem_conn()
payload = json.dumps({"symbol": "ONDO/USDT", "dca_levels": []}, ensure_ascii=False)
for snap_id in (10, 11, 12):
conn.execute(
"""INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, direction, result_label,
snapshot_json, closed_at, created_at, pnl_amount
) VALUES (?,?,?,?,?,?,?,?,?,?)""",
(
snap_id,
STRATEGY_TREND,
7,
"ONDO/USDT",
"short",
"止损",
payload,
"2026-06-08 08:41:00",
"2026-06-08 08:41:00",
-2.2,
),
)
conn.commit()
rows = list_strategy_snapshots(conn, limit=50)
stop_rows = [r for r in rows if int(r.get("source_id") or 0) == 7]
assert len(stop_rows) == 1
assert int(stop_rows[0]["id"]) == 12
def test_dedupe_keeps_manual_over_stop_loss():
conn = _mem_conn()
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
for snap_id, label in ((10, "止损"), (11, "手动平仓")):
conn.execute(
"""INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
) VALUES (?,?,?,?,?,?,?,?,?)""",
(
snap_id,
STRATEGY_TREND,
7,
"ONDO/USDT",
label,
payload,
"2026-06-08 08:44:00",
"2026-06-08 08:44:00",
-2.23,
),
)
conn.commit()
removed = dedupe_strategy_snapshots(conn)
conn.commit()
assert removed == 1
row = conn.execute(
"SELECT result_label FROM strategy_trade_snapshots WHERE source_id=?",
(7,),
).fetchone()
assert row["result_label"] == "手动平仓"
if __name__ == "__main__":
test_save_trend_plan_snapshot_skips_duplicate_result()
test_dedupe_strategy_snapshots_handles_many_duplicates()
test_dedupe_strategy_snapshots_keeps_latest_id()
test_list_strategy_snapshots_hides_duplicate_keys()
test_dedupe_keeps_manual_over_stop_loss()
print("all ok")
+48 -48
View File
@@ -1,48 +1,48 @@
import unittest
from trade_exchange_stats_lib import (
aggregate_bilateral_stats,
commission_usdt_from_fill,
filter_position_lifecycle_fills,
merge_commission_prefer_income,
quote_turnover_usdt_from_fill,
)
class TradeExchangeStatsTests(unittest.TestCase):
def test_turnover_from_cost(self):
t = {"cost": 1000.0, "price": 50, "amount": 20}
self.assertEqual(quote_turnover_usdt_from_fill(t), 1000.0)
def test_commission_from_fee(self):
t = {"fee": {"cost": -0.42, "currency": "USDT"}}
self.assertEqual(commission_usdt_from_fill(t), 0.42)
def test_bilateral_aggregate(self):
fills = [
{"side": "buy", "cost": 500, "fee": {"cost": -0.2, "currency": "USDT"}, "timestamp": 1000},
{"side": "sell", "cost": 520, "fee": {"cost": -0.21, "currency": "USDT"}, "timestamp": 2000},
]
stats = aggregate_bilateral_stats(fills)
self.assertIsNotNone(stats)
self.assertEqual(stats["exchange_turnover_usdt"], 1020.0)
self.assertEqual(stats["exchange_commission_usdt"], 0.41)
def test_filter_long_lifecycle(self):
base = 1_700_000_000_000
trades = [
{"side": "buy", "timestamp": base, "cost": 100},
{"side": "sell", "timestamp": base + 60_000, "cost": 110},
{"side": "buy", "timestamp": base + 120_000, "cost": 999},
]
got = filter_position_lifecycle_fills(
trades, "long", base - 1000, base + 90_000, close_buffer_ms=0
)
self.assertEqual(len(got), 2)
def test_prefer_income_commission(self):
self.assertEqual(merge_commission_prefer_income(0.3, 0.45), 0.45)
if __name__ == "__main__":
unittest.main()
import unittest
from lib.trade.trade_exchange_stats_lib import (
aggregate_bilateral_stats,
commission_usdt_from_fill,
filter_position_lifecycle_fills,
merge_commission_prefer_income,
quote_turnover_usdt_from_fill,
)
class TradeExchangeStatsTests(unittest.TestCase):
def test_turnover_from_cost(self):
t = {"cost": 1000.0, "price": 50, "amount": 20}
self.assertEqual(quote_turnover_usdt_from_fill(t), 1000.0)
def test_commission_from_fee(self):
t = {"fee": {"cost": -0.42, "currency": "USDT"}}
self.assertEqual(commission_usdt_from_fill(t), 0.42)
def test_bilateral_aggregate(self):
fills = [
{"side": "buy", "cost": 500, "fee": {"cost": -0.2, "currency": "USDT"}, "timestamp": 1000},
{"side": "sell", "cost": 520, "fee": {"cost": -0.21, "currency": "USDT"}, "timestamp": 2000},
]
stats = aggregate_bilateral_stats(fills)
self.assertIsNotNone(stats)
self.assertEqual(stats["exchange_turnover_usdt"], 1020.0)
self.assertEqual(stats["exchange_commission_usdt"], 0.41)
def test_filter_long_lifecycle(self):
base = 1_700_000_000_000
trades = [
{"side": "buy", "timestamp": base, "cost": 100},
{"side": "sell", "timestamp": base + 60_000, "cost": 110},
{"side": "buy", "timestamp": base + 120_000, "cost": 999},
]
got = filter_position_lifecycle_fills(
trades, "long", base - 1000, base + 90_000, close_buffer_ms=0
)
self.assertEqual(len(got), 2)
def test_prefer_income_commission(self):
self.assertEqual(merge_commission_prefer_income(0.3, 0.45), 0.45)
if __name__ == "__main__":
unittest.main()
+30 -30
View File
@@ -1,30 +1,30 @@
from trade_result_lib import normalize_result_with_pnl, normalize_display_result, is_winning_pnl
def test_stop_loss_with_profit_becomes_trailing_tp():
assert normalize_result_with_pnl("止损", 4.33) == "移动止盈"
def test_manual_close_unchanged_even_with_profit():
assert normalize_result_with_pnl("手动平仓", 10) == "手动平仓"
def test_stop_loss_with_loss_unchanged():
assert normalize_result_with_pnl("止损", -2.5) == "止损"
def test_take_profit_unchanged():
assert normalize_result_with_pnl("止盈", 5) == "止盈"
def test_external_close_becomes_manual_close():
assert normalize_display_result("外部平仓") == "手动平仓"
assert normalize_result_with_pnl("外部平仓", 2.5) == "手动平仓"
assert normalize_result_with_pnl("外部平仓(自动同步)", -1) == "手动平仓"
def test_winning_pnl_positive_only():
assert is_winning_pnl(2.96) is True
assert is_winning_pnl(0) is False
assert is_winning_pnl(-1.05) is False
assert is_winning_pnl(None) is False
from lib.trade.trade_result_lib import normalize_result_with_pnl, normalize_display_result, is_winning_pnl
def test_stop_loss_with_profit_becomes_trailing_tp():
assert normalize_result_with_pnl("止损", 4.33) == "移动止盈"
def test_manual_close_unchanged_even_with_profit():
assert normalize_result_with_pnl("手动平仓", 10) == "手动平仓"
def test_stop_loss_with_loss_unchanged():
assert normalize_result_with_pnl("止损", -2.5) == "止损"
def test_take_profit_unchanged():
assert normalize_result_with_pnl("止盈", 5) == "止盈"
def test_external_close_becomes_manual_close():
assert normalize_display_result("外部平仓") == "手动平仓"
assert normalize_result_with_pnl("外部平仓", 2.5) == "手动平仓"
assert normalize_result_with_pnl("外部平仓(自动同步)", -1) == "手动平仓"
def test_winning_pnl_positive_only():
assert is_winning_pnl(2.96) is True
assert is_winning_pnl(0) is False
assert is_winning_pnl(-1.05) is False
assert is_winning_pnl(None) is False
+90 -90
View File
@@ -1,90 +1,90 @@
import unittest
from types import SimpleNamespace
from datetime import datetime
from trade_stats_calendar_lib import (
build_initial_stats_calendar,
build_stats_calendar_bootstrap,
build_trade_stats_calendar,
)
def _row(**kwargs):
base = {
"monitor_type": "",
"key_signal_type": "",
"exchange_turnover_usdt": None,
"exchange_commission_usdt": None,
}
base.update(kwargs)
return SimpleNamespace(**base)
def _matches_all(row, segment_key):
return segment_key == "all"
def _matches_manual(row, segment_key):
if segment_key == "all":
return True
if segment_key == "manual":
return (row.monitor_type or "").strip() == "手动" and not (row.key_signal_type or "").strip()
return False
class TradeStatsCalendarLibTests(unittest.TestCase):
def test_groups_by_trading_day_and_segment(self):
pnls = [
(10.0, None, "2026-06-18", _row(monitor_type="手动")),
(-3.0, None, "2026-06-18", _row(monitor_type="手动")),
(5.0, None, "2026-06-19", _row(monitor_type="自动", key_signal_type="箱体突破")),
]
payload = build_trade_stats_calendar(
pnls,
2026,
6,
"manual",
_matches_manual,
reset_hour=8,
)
self.assertEqual(payload["month"], 6)
self.assertEqual(payload["month_open_count"], 2)
days = payload["days"]
self.assertIn("2026-06-18", days)
self.assertNotIn("2026-06-19", days)
self.assertEqual(days["2026-06-18"]["open_count"], 2)
self.assertAlmostEqual(days["2026-06-18"]["pnl_total"], 7.0)
def test_invalid_month_raises(self):
with self.assertRaises(ValueError):
build_trade_stats_calendar([], 2026, 13, "all", _matches_all)
def test_initial_calendar_uses_current_month(self):
pnls = [(2.5, None, "2026-06-20", _row())]
payload = build_initial_stats_calendar(
pnls,
datetime(2026, 6, 26, 12, 0),
_matches_all,
reset_hour=8,
)
self.assertEqual(payload["year"], 2026)
self.assertEqual(payload["month"], 6)
self.assertEqual(payload["month_open_count"], 1)
self.assertIn("2026-06-20", payload["days"])
def test_bootstrap_json_roundtrip(self):
pnls = [(2.5, None, "2026-06-20", _row())]
payload, raw = build_stats_calendar_bootstrap(
pnls,
datetime(2026, 6, 26, 12, 0),
_matches_all,
reset_hour=8,
)
self.assertIsNotNone(payload)
self.assertIsNotNone(raw)
self.assertIn('"month_open_count":1', raw.replace(" ", ""))
if __name__ == "__main__":
unittest.main()
import unittest
from types import SimpleNamespace
from datetime import datetime
from lib.trade.trade_stats_calendar_lib import (
build_initial_stats_calendar,
build_stats_calendar_bootstrap,
build_trade_stats_calendar,
)
def _row(**kwargs):
base = {
"monitor_type": "",
"key_signal_type": "",
"exchange_turnover_usdt": None,
"exchange_commission_usdt": None,
}
base.update(kwargs)
return SimpleNamespace(**base)
def _matches_all(row, segment_key):
return segment_key == "all"
def _matches_manual(row, segment_key):
if segment_key == "all":
return True
if segment_key == "manual":
return (row.monitor_type or "").strip() == "手动" and not (row.key_signal_type or "").strip()
return False
class TradeStatsCalendarLibTests(unittest.TestCase):
def test_groups_by_trading_day_and_segment(self):
pnls = [
(10.0, None, "2026-06-18", _row(monitor_type="手动")),
(-3.0, None, "2026-06-18", _row(monitor_type="手动")),
(5.0, None, "2026-06-19", _row(monitor_type="自动", key_signal_type="箱体突破")),
]
payload = build_trade_stats_calendar(
pnls,
2026,
6,
"manual",
_matches_manual,
reset_hour=8,
)
self.assertEqual(payload["month"], 6)
self.assertEqual(payload["month_open_count"], 2)
days = payload["days"]
self.assertIn("2026-06-18", days)
self.assertNotIn("2026-06-19", days)
self.assertEqual(days["2026-06-18"]["open_count"], 2)
self.assertAlmostEqual(days["2026-06-18"]["pnl_total"], 7.0)
def test_invalid_month_raises(self):
with self.assertRaises(ValueError):
build_trade_stats_calendar([], 2026, 13, "all", _matches_all)
def test_initial_calendar_uses_current_month(self):
pnls = [(2.5, None, "2026-06-20", _row())]
payload = build_initial_stats_calendar(
pnls,
datetime(2026, 6, 26, 12, 0),
_matches_all,
reset_hour=8,
)
self.assertEqual(payload["year"], 2026)
self.assertEqual(payload["month"], 6)
self.assertEqual(payload["month_open_count"], 1)
self.assertIn("2026-06-20", payload["days"])
def test_bootstrap_json_roundtrip(self):
pnls = [(2.5, None, "2026-06-20", _row())]
payload, raw = build_stats_calendar_bootstrap(
pnls,
datetime(2026, 6, 26, 12, 0),
_matches_all,
reset_hour=8,
)
self.assertIsNotNone(payload)
self.assertIsNotNone(raw)
self.assertIn('"month_open_count":1', raw.replace(" ", ""))
if __name__ == "__main__":
unittest.main()
+101 -101
View File
@@ -1,101 +1,101 @@
"""趋势回调运行中计划:实际成交价重算补仓表与金额盈亏比。"""
from __future__ import annotations
import json
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from strategy_snapshot_lib import attach_trend_dca_levels # noqa: E402
from strategy_trend_lib import ( # noqa: E402
calc_trend_plan_money_metrics,
trend_leg_display_price,
)
class TestTrendDcaEnrichFills(unittest.TestCase):
def _base_plan(self, **overrides):
plan = {
"direction": "long",
"stop_loss": 0.329,
"take_profit": 0.476,
"first_order_amount": 115,
"snapshot_available_usdt": 97.98,
"risk_percent": 5,
"contract_size": 1.0,
"grid_prices_json": json.dumps([0.3465, 0.343, 0.3395, 0.336, 0.3325]),
"leg_amounts_json": json.dumps([23, 23, 23, 23, 23]),
"dca_legs": 5,
"first_order_done": 1,
"legs_done": 0,
"avg_entry_price": 0.3537,
"order_amount_open": 115,
"target_order_amount": 230,
"leg_fill_prices_json": json.dumps([0.3537]),
}
plan.update(overrides)
return plan
def test_header_money_rr_not_price_rr(self):
plan = self._base_plan()
metrics = calc_trend_plan_money_metrics(plan)
self.assertAlmostEqual(metrics["risk_amount_u"], 4.899, places=2)
self.assertIsNotNone(metrics["money_rr"])
self.assertLess(metrics["money_rr"], 4.0)
def test_done_dca_uses_actual_fill_price(self):
plan = self._base_plan(
legs_done=1,
avg_entry_price=0.3512,
order_amount_open=138,
leg_fill_prices_json=json.dumps([0.3537, 0.3458]),
)
enriched = attach_trend_dca_levels(plan)
levels = enriched["dca_levels"]
self.assertEqual(len(levels), 6)
dca1 = levels[1]
self.assertEqual(dca1["status"], "done")
self.assertAlmostEqual(dca1["price"], 0.3458, places=4)
self.assertIsNotNone(dca1["avg_entry"])
self.assertIsNotNone(dca1["rr"])
dca2 = levels[2]
self.assertEqual(dca2["status"], "pending")
self.assertAlmostEqual(dca2["price"], 0.343, places=4)
def test_missing_dca_fills_use_grid_trigger_not_inferred_price(self):
"""缺补仓成交价时:触发价用计划网格,末档均价对齐头部,禁止反推离谱成交价。"""
plan = self._base_plan(
legs_done=2,
avg_entry_price=0.3507,
order_amount_open=161,
leg_fill_prices_json=json.dumps([0.3436]),
grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]),
)
enriched = attach_trend_dca_levels(plan)
levels = enriched["dca_levels"]
dca1 = levels[1]
dca2 = levels[2]
self.assertEqual(dca1["status"], "done")
self.assertAlmostEqual(dca1["price"], 0.343, places=4)
self.assertEqual(dca2["status"], "done")
self.assertAlmostEqual(dca2["price"], 0.343, places=4)
self.assertAlmostEqual(dca2["avg_entry"], 0.3507, places=4)
self.assertLess(dca2["price"], 0.36)
def test_display_price_never_infers_from_target_avg(self):
"""四所共用:缺记录时只用网格,不因均价反推离谱触发价。"""
plan = self._base_plan(
legs_done=2,
avg_entry_price=0.3507,
leg_fill_prices_json=json.dumps([0.3436]),
grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]),
)
self.assertAlmostEqual(trend_leg_display_price(plan, 2), 0.343, places=4)
self.assertLess(trend_leg_display_price(plan, 2), 0.36)
if __name__ == "__main__":
unittest.main()
"""趋势回调运行中计划:实际成交价重算补仓表与金额盈亏比。"""
from __future__ import annotations
import json
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from lib.strategy.strategy_snapshot_lib import attach_trend_dca_levels # noqa: E402
from lib.strategy.strategy_trend_lib import ( # noqa: E402
calc_trend_plan_money_metrics,
trend_leg_display_price,
)
class TestTrendDcaEnrichFills(unittest.TestCase):
def _base_plan(self, **overrides):
plan = {
"direction": "long",
"stop_loss": 0.329,
"take_profit": 0.476,
"first_order_amount": 115,
"snapshot_available_usdt": 97.98,
"risk_percent": 5,
"contract_size": 1.0,
"grid_prices_json": json.dumps([0.3465, 0.343, 0.3395, 0.336, 0.3325]),
"leg_amounts_json": json.dumps([23, 23, 23, 23, 23]),
"dca_legs": 5,
"first_order_done": 1,
"legs_done": 0,
"avg_entry_price": 0.3537,
"order_amount_open": 115,
"target_order_amount": 230,
"leg_fill_prices_json": json.dumps([0.3537]),
}
plan.update(overrides)
return plan
def test_header_money_rr_not_price_rr(self):
plan = self._base_plan()
metrics = calc_trend_plan_money_metrics(plan)
self.assertAlmostEqual(metrics["risk_amount_u"], 4.899, places=2)
self.assertIsNotNone(metrics["money_rr"])
self.assertLess(metrics["money_rr"], 4.0)
def test_done_dca_uses_actual_fill_price(self):
plan = self._base_plan(
legs_done=1,
avg_entry_price=0.3512,
order_amount_open=138,
leg_fill_prices_json=json.dumps([0.3537, 0.3458]),
)
enriched = attach_trend_dca_levels(plan)
levels = enriched["dca_levels"]
self.assertEqual(len(levels), 6)
dca1 = levels[1]
self.assertEqual(dca1["status"], "done")
self.assertAlmostEqual(dca1["price"], 0.3458, places=4)
self.assertIsNotNone(dca1["avg_entry"])
self.assertIsNotNone(dca1["rr"])
dca2 = levels[2]
self.assertEqual(dca2["status"], "pending")
self.assertAlmostEqual(dca2["price"], 0.343, places=4)
def test_missing_dca_fills_use_grid_trigger_not_inferred_price(self):
"""缺补仓成交价时:触发价用计划网格,末档均价对齐头部,禁止反推离谱成交价。"""
plan = self._base_plan(
legs_done=2,
avg_entry_price=0.3507,
order_amount_open=161,
leg_fill_prices_json=json.dumps([0.3436]),
grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]),
)
enriched = attach_trend_dca_levels(plan)
levels = enriched["dca_levels"]
dca1 = levels[1]
dca2 = levels[2]
self.assertEqual(dca1["status"], "done")
self.assertAlmostEqual(dca1["price"], 0.343, places=4)
self.assertEqual(dca2["status"], "done")
self.assertAlmostEqual(dca2["price"], 0.343, places=4)
self.assertAlmostEqual(dca2["avg_entry"], 0.3507, places=4)
self.assertLess(dca2["price"], 0.36)
def test_display_price_never_infers_from_target_avg(self):
"""四所共用:缺记录时只用网格,不因均价反推离谱触发价。"""
plan = self._base_plan(
legs_done=2,
avg_entry_price=0.3507,
leg_fill_prices_json=json.dumps([0.3436]),
grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]),
)
self.assertAlmostEqual(trend_leg_display_price(plan, 2), 0.343, places=4)
self.assertLess(trend_leg_display_price(plan, 2), 0.36)
if __name__ == "__main__":
unittest.main()
+43 -43
View File
@@ -1,43 +1,43 @@
"""趋势回调:补仓触达与有效保证金估算。"""
from strategy_trend_lib import trend_dca_level_reached, trend_effective_margin_capital
def test_trend_dca_short_monotonic_up_fills_missed_legs():
"""做空价升:旧逻辑需 last<level,价越过 0.3437 后 last 已高于该档则永不补仓。"""
direction = "short"
levels = [0.3413, 0.3437, 0.346, 0.3483, 0.3507]
pf = 0.353
filled = [lv for lv in levels if trend_dca_level_reached(direction, pf, lv)]
assert filled == levels
def test_trend_dca_short_not_before_first_level():
direction = "short"
assert not trend_dca_level_reached(direction, 0.336, 0.3413)
assert trend_dca_level_reached(direction, 0.3413, 0.3413)
def test_trend_dca_long_mark_below_trigger():
direction = "long"
assert trend_dca_level_reached(direction, 0.344, 0.3465)
assert not trend_dca_level_reached(direction, 0.347, 0.3465)
def test_trend_effective_margin_first_leg_only():
plan = {
"plan_margin_capital": 12.11,
"target_order_amount": 359.0,
"order_amount_open": 179.0,
"first_order_amount": 179.0,
}
m = trend_effective_margin_capital(plan)
assert abs(m - 12.11 * 179 / 359) < 0.01
def test_trend_effective_margin_full_position():
plan = {
"plan_margin_capital": 12.11,
"target_order_amount": 359.0,
"order_amount_open": 359.0,
}
assert trend_effective_margin_capital(plan) == 12.11
"""趋势回调:补仓触达与有效保证金估算。"""
from lib.strategy.strategy_trend_lib import trend_dca_level_reached, trend_effective_margin_capital
def test_trend_dca_short_monotonic_up_fills_missed_legs():
"""做空价升:旧逻辑需 last<level,价越过 0.3437 后 last 已高于该档则永不补仓。"""
direction = "short"
levels = [0.3413, 0.3437, 0.346, 0.3483, 0.3507]
pf = 0.353
filled = [lv for lv in levels if trend_dca_level_reached(direction, pf, lv)]
assert filled == levels
def test_trend_dca_short_not_before_first_level():
direction = "short"
assert not trend_dca_level_reached(direction, 0.336, 0.3413)
assert trend_dca_level_reached(direction, 0.3413, 0.3413)
def test_trend_dca_long_mark_below_trigger():
direction = "long"
assert trend_dca_level_reached(direction, 0.344, 0.3465)
assert not trend_dca_level_reached(direction, 0.347, 0.3465)
def test_trend_effective_margin_first_leg_only():
plan = {
"plan_margin_capital": 12.11,
"target_order_amount": 359.0,
"order_amount_open": 179.0,
"first_order_amount": 179.0,
}
m = trend_effective_margin_capital(plan)
assert abs(m - 12.11 * 179 / 359) < 0.01
def test_trend_effective_margin_full_position():
plan = {
"plan_margin_capital": 12.11,
"target_order_amount": 359.0,
"order_amount_open": 359.0,
}
assert trend_effective_margin_capital(plan) == 12.11
+92 -92
View File
@@ -1,92 +1,92 @@
"""趋势计划结束:须写入 trade_records(四所统一)。"""
from __future__ import annotations
import inspect
import sqlite3
import sys
import unittest
from pathlib import Path
from unittest.mock import MagicMock
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from strategy_trend_register import _call_insert_trade_record # noqa: E402
class _GateBotLikeModule:
"""模拟 gate_bot:曾有 trend_plan_id 但缺 entry_reason 参数。"""
@staticmethod
def insert_trade_record(
conn,
symbol,
monitor_type,
direction,
trigger_price,
stop_loss,
initial_stop_loss=None,
take_profit=None,
margin_capital=None,
leverage=None,
pnl_amount=0,
hold_seconds=0,
trade_style=None,
risk_amount=None,
planned_rr=None,
actual_rr=None,
result="",
miss_reason=None,
opened_at=None,
opened_at_ms=None,
closed_at=None,
closed_at_ms=None,
exchange_trade_id=None,
trend_plan_id=None,
):
conn.execute(
"INSERT INTO trade_records (symbol, monitor_type, direction, result, trend_plan_id) "
"VALUES (?,?,?,?,?)",
(symbol, monitor_type, direction, result, trend_plan_id),
)
class TestTrendFinalizeTradeRecord(unittest.TestCase):
def test_call_insert_filters_unknown_entry_reason(self):
conn = sqlite3.connect(":memory:")
conn.execute(
"CREATE TABLE trade_records (symbol TEXT, monitor_type TEXT, direction TEXT, "
"result TEXT, trend_plan_id INTEGER)"
)
m = _GateBotLikeModule()
_call_insert_trade_record(
m,
4,
dict(
conn=conn,
symbol="ONDO/USDT",
monitor_type="趋势回调",
direction="long",
trigger_price=0.35,
stop_loss=0.329,
result="止损",
entry_reason="趋势回调",
),
)
row = conn.execute(
"SELECT symbol, monitor_type, trend_plan_id FROM trade_records"
).fetchone()
self.assertEqual(row[0], "ONDO/USDT")
self.assertEqual(row[1], "趋势回调")
self.assertEqual(row[2], 4)
def test_gate_bot_insert_accepts_entry_reason(self):
from crypto_monitor_gate_bot import app as gate_bot_app # noqa: E402
sig = inspect.signature(gate_bot_app.insert_trade_record)
self.assertIn("entry_reason", sig.parameters)
self.assertIn("trend_plan_id", sig.parameters)
if __name__ == "__main__":
unittest.main()
"""趋势计划结束:须写入 trade_records(四所统一)。"""
from __future__ import annotations
import inspect
import sqlite3
import sys
import unittest
from pathlib import Path
from unittest.mock import MagicMock
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from lib.strategy.strategy_trend_register import _call_insert_trade_record # noqa: E402
class _GateBotLikeModule:
"""模拟 gate_bot:曾有 trend_plan_id 但缺 entry_reason 参数。"""
@staticmethod
def insert_trade_record(
conn,
symbol,
monitor_type,
direction,
trigger_price,
stop_loss,
initial_stop_loss=None,
take_profit=None,
margin_capital=None,
leverage=None,
pnl_amount=0,
hold_seconds=0,
trade_style=None,
risk_amount=None,
planned_rr=None,
actual_rr=None,
result="",
miss_reason=None,
opened_at=None,
opened_at_ms=None,
closed_at=None,
closed_at_ms=None,
exchange_trade_id=None,
trend_plan_id=None,
):
conn.execute(
"INSERT INTO trade_records (symbol, monitor_type, direction, result, trend_plan_id) "
"VALUES (?,?,?,?,?)",
(symbol, monitor_type, direction, result, trend_plan_id),
)
class TestTrendFinalizeTradeRecord(unittest.TestCase):
def test_call_insert_filters_unknown_entry_reason(self):
conn = sqlite3.connect(":memory:")
conn.execute(
"CREATE TABLE trade_records (symbol TEXT, monitor_type TEXT, direction TEXT, "
"result TEXT, trend_plan_id INTEGER)"
)
m = _GateBotLikeModule()
_call_insert_trade_record(
m,
4,
dict(
conn=conn,
symbol="ONDO/USDT",
monitor_type="趋势回调",
direction="long",
trigger_price=0.35,
stop_loss=0.329,
result="止损",
entry_reason="趋势回调",
),
)
row = conn.execute(
"SELECT symbol, monitor_type, trend_plan_id FROM trade_records"
).fetchone()
self.assertEqual(row[0], "ONDO/USDT")
self.assertEqual(row[1], "趋势回调")
self.assertEqual(row[2], 4)
def test_gate_bot_insert_accepts_entry_reason(self):
from crypto_monitor_gate_bot import app as gate_bot_app # noqa: E402
sig = inspect.signature(gate_bot_app.insert_trade_record)
self.assertIn("entry_reason", sig.parameters)
self.assertIn("trend_plan_id", sig.parameters)
if __name__ == "__main__":
unittest.main()
+44 -44
View File
@@ -1,44 +1,44 @@
"""趋势回调中控 enrich:补仓次数与加仓价。"""
from __future__ import annotations
import json
import sys
import unittest
from pathlib import Path
from unittest.mock import MagicMock
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from strategy_trend_register import _trend_add_leg_fields # noqa: E402
class TestTrendHubEnrich(unittest.TestCase):
def test_add_count_and_prices(self):
mock_ex = MagicMock()
mock_ex.price_to_precision = lambda sym, px: f"{float(px):.4f}"
app_mod = MagicMock()
app_mod.exchange = mock_ex
app_mod.ensure_markets_loaded = MagicMock()
app_mod.normalize_exchange_symbol = lambda s: s
cfg = {"app_module": app_mod}
raw = {
"symbol": "ETH/USDT",
"exchange_symbol": "ETH/USDT:USDT",
"legs_done": 2,
"dca_legs": 5,
"grid_prices_json": json.dumps([1800.1, 1750.2, 1700.3]),
"stop_loss": 1600,
"take_profit": 2000,
"avg_entry_price": 1820.5,
}
out = _trend_add_leg_fields(cfg, raw)
self.assertEqual(out["add_count"], 2)
self.assertEqual(out["add_count_total"], 5)
self.assertEqual(out["add_prices"], [1800.1, 1750.2])
self.assertEqual(len(out["add_prices_display"]), 2)
self.assertEqual(out["stop_loss_display"], "1600.0000")
if __name__ == "__main__":
unittest.main()
"""趋势回调中控 enrich:补仓次数与加仓价。"""
from __future__ import annotations
import json
import sys
import unittest
from pathlib import Path
from unittest.mock import MagicMock
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from lib.strategy.strategy_trend_register import _trend_add_leg_fields # noqa: E402
class TestTrendHubEnrich(unittest.TestCase):
def test_add_count_and_prices(self):
mock_ex = MagicMock()
mock_ex.price_to_precision = lambda sym, px: f"{float(px):.4f}"
app_mod = MagicMock()
app_mod.exchange = mock_ex
app_mod.ensure_markets_loaded = MagicMock()
app_mod.normalize_exchange_symbol = lambda s: s
cfg = {"app_module": app_mod}
raw = {
"symbol": "ETH/USDT",
"exchange_symbol": "ETH/USDT:USDT",
"legs_done": 2,
"dca_legs": 5,
"grid_prices_json": json.dumps([1800.1, 1750.2, 1700.3]),
"stop_loss": 1600,
"take_profit": 2000,
"avg_entry_price": 1820.5,
}
out = _trend_add_leg_fields(cfg, raw)
self.assertEqual(out["add_count"], 2)
self.assertEqual(out["add_count_total"], 5)
self.assertEqual(out["add_prices"], [1800.1, 1750.2])
self.assertEqual(len(out["add_prices_display"]), 2)
self.assertEqual(out["stop_loss_display"], "1600.0000")
if __name__ == "__main__":
unittest.main()
+92 -92
View File
@@ -1,92 +1,92 @@
"""四所趋势 enrich:实例与中控 monitor 字段一致。"""
from __future__ import annotations
import json
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from strategy_trend_register import ( # noqa: E402
enrich_trend_plan,
enrich_trend_plan_for_hub,
)
class _FakeModule:
@staticmethod
def normalize_exchange_symbol(sym):
return sym
@staticmethod
def ensure_markets_loaded():
return None
@staticmethod
def get_live_position_exchange_metrics(ex_sym, direction, order_leverage=None):
return {
"entry_price": 0.3507,
"mark_price": 0.3431,
"unrealized_pnl": -1.23,
}
@staticmethod
def get_contract_size(_ex_sym):
return 1.0
class TestTrendHubEnrichUnified(unittest.TestCase):
def _cfg(self):
return {
"app_module": _FakeModule(),
"breakeven_offset_pct": 0.3,
"row_to_dict": lambda row: dict(row) if not isinstance(row, dict) else row,
}
def _plan_row(self):
return {
"id": 4,
"symbol": "ONDO/USDT:USDT",
"exchange_symbol": "ONDO/USDT:USDT",
"direction": "long",
"stop_loss": 0.329,
"take_profit": 0.476,
"add_upper": 0.35,
"first_order_amount": 115,
"snapshot_available_usdt": 97.98,
"risk_percent": 5,
"contract_size": 1.0,
"grid_prices_json": json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]),
"leg_amounts_json": json.dumps([23, 23, 23, 23, 23]),
"dca_legs": 5,
"first_order_done": 1,
"legs_done": 2,
"avg_entry_price": 0.3434,
"order_amount_open": 161,
"leg_fill_prices_json": json.dumps([0.3436]),
"leverage": 10,
"plan_margin_capital": 8.17,
}
def test_hub_and_page_share_live_avg_and_dca_levels(self):
cfg = self._cfg()
row = self._plan_row()
page = enrich_trend_plan(cfg, row)
hub = enrich_trend_plan_for_hub(cfg, row)
self.assertAlmostEqual(page["avg_entry_price"], 0.3507, places=4)
self.assertAlmostEqual(hub["avg_entry_price"], 0.3507, places=4)
self.assertIn("dca_levels", page)
self.assertIn("dca_levels", hub)
last_done = hub["dca_levels"][2]
self.assertEqual(last_done["status"], "done")
self.assertAlmostEqual(last_done["price"], 0.343, places=4)
self.assertAlmostEqual(last_done["avg_entry"], 0.3507, places=4)
self.assertLess(last_done["price"], 0.36)
self.assertEqual(hub.get("monitor_source"), "趋势回调计划")
self.assertEqual(hub.get("add_count"), 2)
if __name__ == "__main__":
unittest.main()
"""四所趋势 enrich:实例与中控 monitor 字段一致。"""
from __future__ import annotations
import json
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from lib.strategy.strategy_trend_register import ( # noqa: E402
enrich_trend_plan,
enrich_trend_plan_for_hub,
)
class _FakeModule:
@staticmethod
def normalize_exchange_symbol(sym):
return sym
@staticmethod
def ensure_markets_loaded():
return None
@staticmethod
def get_live_position_exchange_metrics(ex_sym, direction, order_leverage=None):
return {
"entry_price": 0.3507,
"mark_price": 0.3431,
"unrealized_pnl": -1.23,
}
@staticmethod
def get_contract_size(_ex_sym):
return 1.0
class TestTrendHubEnrichUnified(unittest.TestCase):
def _cfg(self):
return {
"app_module": _FakeModule(),
"breakeven_offset_pct": 0.3,
"row_to_dict": lambda row: dict(row) if not isinstance(row, dict) else row,
}
def _plan_row(self):
return {
"id": 4,
"symbol": "ONDO/USDT:USDT",
"exchange_symbol": "ONDO/USDT:USDT",
"direction": "long",
"stop_loss": 0.329,
"take_profit": 0.476,
"add_upper": 0.35,
"first_order_amount": 115,
"snapshot_available_usdt": 97.98,
"risk_percent": 5,
"contract_size": 1.0,
"grid_prices_json": json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]),
"leg_amounts_json": json.dumps([23, 23, 23, 23, 23]),
"dca_legs": 5,
"first_order_done": 1,
"legs_done": 2,
"avg_entry_price": 0.3434,
"order_amount_open": 161,
"leg_fill_prices_json": json.dumps([0.3436]),
"leverage": 10,
"plan_margin_capital": 8.17,
}
def test_hub_and_page_share_live_avg_and_dca_levels(self):
cfg = self._cfg()
row = self._plan_row()
page = enrich_trend_plan(cfg, row)
hub = enrich_trend_plan_for_hub(cfg, row)
self.assertAlmostEqual(page["avg_entry_price"], 0.3507, places=4)
self.assertAlmostEqual(hub["avg_entry_price"], 0.3507, places=4)
self.assertIn("dca_levels", page)
self.assertIn("dca_levels", hub)
last_done = hub["dca_levels"][2]
self.assertEqual(last_done["status"], "done")
self.assertAlmostEqual(last_done["price"], 0.343, places=4)
self.assertAlmostEqual(last_done["avg_entry"], 0.3507, places=4)
self.assertLess(last_done["price"], 0.36)
self.assertEqual(hub.get("monitor_source"), "趋势回调计划")
self.assertEqual(hub.get("add_count"), 2)
if __name__ == "__main__":
unittest.main()
+41 -41
View File
@@ -1,41 +1,41 @@
"""趋势补仓下单:空 params 不得变成 Noneccxt 会报 not iterable)。"""
from __future__ import annotations
import unittest
from unittest.mock import MagicMock
from strategy_trend_exchange import trend_market_add
class TestTrendMarketAddParams(unittest.TestCase):
def test_empty_gate_params_not_passed_as_none(self):
ex = MagicMock()
ex.create_order.return_value = {"id": "1", "average": 0.34}
app = MagicMock()
app.exchange = ex
app.ensure_markets_loaded = MagicMock()
app.build_gate_order_params = MagicMock(return_value={})
cfg = {"app_module": app}
trend_market_add(cfg, "ONDO/USDT:USDT", "long", 23, 10)
args = ex.create_order.call_args
self.assertEqual(args[0][5], {})
self.assertIsNotNone(args[0][5])
def test_binance_oneway_empty_params_not_passed_as_none(self):
ex = MagicMock()
ex.create_order.return_value = {"id": "1"}
app = MagicMock(spec=["exchange", "ensure_markets_loaded", "build_binance_order_params"])
app.exchange = ex
app.ensure_markets_loaded = MagicMock()
app.build_binance_order_params = MagicMock(return_value={})
cfg = {"app_module": app}
trend_market_add(cfg, "BTC/USDT:USDT", "long", 1, 10)
self.assertEqual(ex.create_order.call_args[0][5], {})
if __name__ == "__main__":
unittest.main()
"""趋势补仓下单:空 params 不得变成 Noneccxt 会报 not iterable)。"""
from __future__ import annotations
import unittest
from unittest.mock import MagicMock
from lib.strategy.strategy_trend_exchange import trend_market_add
class TestTrendMarketAddParams(unittest.TestCase):
def test_empty_gate_params_not_passed_as_none(self):
ex = MagicMock()
ex.create_order.return_value = {"id": "1", "average": 0.34}
app = MagicMock()
app.exchange = ex
app.ensure_markets_loaded = MagicMock()
app.build_gate_order_params = MagicMock(return_value={})
cfg = {"app_module": app}
trend_market_add(cfg, "ONDO/USDT:USDT", "long", 23, 10)
args = ex.create_order.call_args
self.assertEqual(args[0][5], {})
self.assertIsNotNone(args[0][5])
def test_binance_oneway_empty_params_not_passed_as_none(self):
ex = MagicMock()
ex.create_order.return_value = {"id": "1"}
app = MagicMock(spec=["exchange", "ensure_markets_loaded", "build_binance_order_params"])
app.exchange = ex
app.ensure_markets_loaded = MagicMock()
app.build_binance_order_params = MagicMock(return_value={})
cfg = {"app_module": app}
trend_market_add(cfg, "BTC/USDT:USDT", "long", 1, 10)
self.assertEqual(ex.create_order.call_args[0][5], {})
if __name__ == "__main__":
unittest.main()
+58 -58
View File
@@ -1,58 +1,58 @@
"""趋势回调预览:止盈盈利 U、止损金额 U、金额盈亏比。"""
from __future__ import annotations
import json
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from strategy_trend_lib import ( # noqa: E402
build_trend_preview_level_rows,
calc_money_reward_risk_ratio,
calc_risk_budget_usdt,
calc_tp_profit_usdt,
)
class TestTrendPreviewTp(unittest.TestCase):
def test_risk_budget_from_snapshot(self):
self.assertAlmostEqual(calc_risk_budget_usdt(110.73, 5), 5.5365, places=2)
def test_short_profit_at_form_take_profit(self):
profit = calc_tp_profit_usdt("short", 72.53, 66.0, 1114, 0.00167)
self.assertIsNotNone(profit)
self.assertGreater(profit, 0)
rr = calc_money_reward_risk_ratio(profit, 5.5365)
self.assertIsNotNone(rr)
self.assertGreater(rr, 1.5)
def test_preview_levels_use_money_rr(self):
preview = {
"direction": "short",
"live_price_ref": 72.53,
"stop_loss": 75.5,
"take_profit": 66.0,
"first_order_amount": 1114,
"snapshot_available_usdt": 110.73,
"risk_percent": 5,
"contract_size": 0.00167,
"grid_prices_json": json.dumps([73.42, 73.83]),
"leg_amounts_json": json.dumps([222, 222]),
}
enriched, rows = build_trend_preview_level_rows(preview)
self.assertAlmostEqual(enriched["preview_risk_amount_u"], 5.5365, places=2)
self.assertEqual(enriched["preview_take_profit_price"], 66.0)
self.assertEqual(len(rows), 3)
self.assertEqual(rows[0]["label"], "首仓")
self.assertEqual(rows[0]["risk_u"], enriched["preview_risk_amount_u"])
self.assertIsNotNone(rows[0]["profit_u"])
self.assertAlmostEqual(rows[0]["rr"], rows[0]["profit_u"] / 5.5365, places=2)
self.assertEqual(rows[1]["risk_u"], enriched["preview_risk_amount_u"])
self.assertGreater(rows[2]["profit_u"], rows[1]["profit_u"])
if __name__ == "__main__":
unittest.main()
"""趋势回调预览:止盈盈利 U、止损金额 U、金额盈亏比。"""
from __future__ import annotations
import json
import sys
import unittest
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from lib.strategy.strategy_trend_lib import ( # noqa: E402
build_trend_preview_level_rows,
calc_money_reward_risk_ratio,
calc_risk_budget_usdt,
calc_tp_profit_usdt,
)
class TestTrendPreviewTp(unittest.TestCase):
def test_risk_budget_from_snapshot(self):
self.assertAlmostEqual(calc_risk_budget_usdt(110.73, 5), 5.5365, places=2)
def test_short_profit_at_form_take_profit(self):
profit = calc_tp_profit_usdt("short", 72.53, 66.0, 1114, 0.00167)
self.assertIsNotNone(profit)
self.assertGreater(profit, 0)
rr = calc_money_reward_risk_ratio(profit, 5.5365)
self.assertIsNotNone(rr)
self.assertGreater(rr, 1.5)
def test_preview_levels_use_money_rr(self):
preview = {
"direction": "short",
"live_price_ref": 72.53,
"stop_loss": 75.5,
"take_profit": 66.0,
"first_order_amount": 1114,
"snapshot_available_usdt": 110.73,
"risk_percent": 5,
"contract_size": 0.00167,
"grid_prices_json": json.dumps([73.42, 73.83]),
"leg_amounts_json": json.dumps([222, 222]),
}
enriched, rows = build_trend_preview_level_rows(preview)
self.assertAlmostEqual(enriched["preview_risk_amount_u"], 5.5365, places=2)
self.assertEqual(enriched["preview_take_profit_price"], 66.0)
self.assertEqual(len(rows), 3)
self.assertEqual(rows[0]["label"], "首仓")
self.assertEqual(rows[0]["risk_u"], enriched["preview_risk_amount_u"])
self.assertIsNotNone(rows[0]["profit_u"])
self.assertAlmostEqual(rows[0]["rr"], rows[0]["profit_u"] / 5.5365, places=2)
self.assertEqual(rows[1]["risk_u"], enriched["preview_risk_amount_u"])
self.assertGreater(rows[2]["profit_u"], rows[1]["profit_u"])
if __name__ == "__main__":
unittest.main()
+85 -85
View File
@@ -1,85 +1,85 @@
"""触价开仓(回调/突破)关键位监控单元测试。"""
from trigger_entry_key_monitor_lib import (
BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE,
CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE,
LEGACY_TRIGGER_ENTRY_MONITOR_TYPE,
TRIGGER_ENTRY_MONITOR_TYPES,
TRIGGER_ENTRY_VALIDITY_HOURS,
breakout_trigger_entry_crossed,
check_trigger_entry_intent_limit,
is_breakout_trigger_entry_key_monitor_type,
is_trigger_entry_key_monitor_type,
trigger_entry_invalidate,
trigger_entry_reached,
trigger_should_fire,
validate_trigger_entry_geometry,
)
class _FakeConn:
def execute(self, sql, params=()):
class R:
def fetchone(self_inner):
return (2,)
return R()
def test_trigger_entry_reached_long():
assert trigger_entry_reached("long", 2049.0, 2050.0) is True
assert trigger_entry_reached("long", 2051.0, 2050.0) is False
def test_breakout_cross_long_up():
assert breakout_trigger_entry_crossed("long", 99.0, 100.5, 100.0) is True
assert breakout_trigger_entry_crossed("long", None, 101.0, 100.0) is True
assert breakout_trigger_entry_crossed("long", 100.0, 100.0, 100.0) is False
def test_breakout_cross_short_down():
assert breakout_trigger_entry_crossed("short", 101.0, 99.5, 100.0) is True
assert breakout_trigger_entry_crossed("short", None, 99.0, 100.0) is True
def test_trigger_should_fire_modes():
assert trigger_should_fire(CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, "long", 2049.0, 2050.0) is True
assert trigger_should_fire(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, "long", 100.5, 100.0, 99.0) is True
def test_validate_geometry_callback_long():
assert validate_trigger_entry_geometry("long", 2050, 2000, 2100, 2090) is None
def test_validate_geometry_breakout_short_requires_mark_above_entry():
assert (
validate_trigger_entry_geometry(
"short", 551, 568, 540, 560, monitor_type=BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE
)
is None
)
err = validate_trigger_entry_geometry(
"short", 551, 568, 540, 550, monitor_type=BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE
)
assert err is not None
assert "高于入场价" in err
def test_invalidate_breakout_sl_side():
assert trigger_entry_invalidate(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, "long", 96, 97, 110) == "sl"
assert trigger_entry_invalidate(CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, "long", 96, 97, 110) is None
def test_intent_limit():
ok, msg = check_trigger_entry_intent_limit(_FakeConn(), "2026-06-07", 2, 3)
assert ok is False
assert "意图" in msg
def test_type_names():
assert is_trigger_entry_key_monitor_type(CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE)
assert is_trigger_entry_key_monitor_type(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE)
assert is_trigger_entry_key_monitor_type(LEGACY_TRIGGER_ENTRY_MONITOR_TYPE)
assert is_breakout_trigger_entry_key_monitor_type(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE)
assert CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE in TRIGGER_ENTRY_MONITOR_TYPES
assert BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE in TRIGGER_ENTRY_MONITOR_TYPES
assert TRIGGER_ENTRY_VALIDITY_HOURS == 24
"""触价开仓(回调/突破)关键位监控单元测试。"""
from lib.key_monitor.trigger_entry_key_monitor_lib import (
BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE,
CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE,
LEGACY_TRIGGER_ENTRY_MONITOR_TYPE,
TRIGGER_ENTRY_MONITOR_TYPES,
TRIGGER_ENTRY_VALIDITY_HOURS,
breakout_trigger_entry_crossed,
check_trigger_entry_intent_limit,
is_breakout_trigger_entry_key_monitor_type,
is_trigger_entry_key_monitor_type,
trigger_entry_invalidate,
trigger_entry_reached,
trigger_should_fire,
validate_trigger_entry_geometry,
)
class _FakeConn:
def execute(self, sql, params=()):
class R:
def fetchone(self_inner):
return (2,)
return R()
def test_trigger_entry_reached_long():
assert trigger_entry_reached("long", 2049.0, 2050.0) is True
assert trigger_entry_reached("long", 2051.0, 2050.0) is False
def test_breakout_cross_long_up():
assert breakout_trigger_entry_crossed("long", 99.0, 100.5, 100.0) is True
assert breakout_trigger_entry_crossed("long", None, 101.0, 100.0) is True
assert breakout_trigger_entry_crossed("long", 100.0, 100.0, 100.0) is False
def test_breakout_cross_short_down():
assert breakout_trigger_entry_crossed("short", 101.0, 99.5, 100.0) is True
assert breakout_trigger_entry_crossed("short", None, 99.0, 100.0) is True
def test_trigger_should_fire_modes():
assert trigger_should_fire(CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, "long", 2049.0, 2050.0) is True
assert trigger_should_fire(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, "long", 100.5, 100.0, 99.0) is True
def test_validate_geometry_callback_long():
assert validate_trigger_entry_geometry("long", 2050, 2000, 2100, 2090) is None
def test_validate_geometry_breakout_short_requires_mark_above_entry():
assert (
validate_trigger_entry_geometry(
"short", 551, 568, 540, 560, monitor_type=BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE
)
is None
)
err = validate_trigger_entry_geometry(
"short", 551, 568, 540, 550, monitor_type=BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE
)
assert err is not None
assert "高于入场价" in err
def test_invalidate_breakout_sl_side():
assert trigger_entry_invalidate(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, "long", 96, 97, 110) == "sl"
assert trigger_entry_invalidate(CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, "long", 96, 97, 110) is None
def test_intent_limit():
ok, msg = check_trigger_entry_intent_limit(_FakeConn(), "2026-06-07", 2, 3)
assert ok is False
assert "意图" in msg
def test_type_names():
assert is_trigger_entry_key_monitor_type(CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE)
assert is_trigger_entry_key_monitor_type(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE)
assert is_trigger_entry_key_monitor_type(LEGACY_TRIGGER_ENTRY_MONITOR_TYPE)
assert is_breakout_trigger_entry_key_monitor_type(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE)
assert CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE in TRIGGER_ENTRY_MONITOR_TYPES
assert BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE in TRIGGER_ENTRY_MONITOR_TYPES
assert TRIGGER_ENTRY_VALIDITY_HOURS == 24