refactor: 将共用代码迁入 lib/ 模块化目录
统一 strategy、key_monitor、trade、hub 等共用库到 lib/ 子包,并补充 lib-structure 文档,便于四所与中控维护。 Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
+526
-526
File diff suppressed because it is too large
Load Diff
+63
-63
@@ -1,63 +1,63 @@
|
||||
"""AI 复盘 journal 文本格式化(四所共用)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from ai_review_lib import journal_row_lines_for_ai # noqa: E402
|
||||
|
||||
|
||||
class TestAiReviewLib(unittest.TestCase):
|
||||
def test_journal_row_includes_expect_and_actual_rr(self):
|
||||
text = journal_row_lines_for_ai(
|
||||
1,
|
||||
{
|
||||
"coin": "HYPE",
|
||||
"tf": "5m",
|
||||
"pnl": "10.73",
|
||||
"real_rr": "2.1354",
|
||||
"expect_rr": "-",
|
||||
"entry_reason": "趋势回调",
|
||||
"exit_reason": "移动止盈",
|
||||
"hold_duration": "1天 3小时",
|
||||
"mood_issues": "",
|
||||
"post_breakeven_stare": "否",
|
||||
"new_trade_while_occupied": "否",
|
||||
"note": "测试备注",
|
||||
},
|
||||
)
|
||||
self.assertIn("实际RR:2.1354", text)
|
||||
self.assertIn("预期RR:-", text)
|
||||
self.assertIn("开仓逻辑:趋势回调", text)
|
||||
self.assertIn("备注:测试备注", text)
|
||||
self.assertNotIn("开仓类型", text)
|
||||
|
||||
def test_journal_row_accepts_sqlite_row(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""CREATE TABLE journal_entries (
|
||||
coin TEXT, tf TEXT, pnl TEXT, real_rr TEXT, expect_rr TEXT,
|
||||
entry_reason TEXT, exit_reason TEXT, hold_duration TEXT,
|
||||
mood_issues TEXT, mood_score INTEGER, note TEXT
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"""INSERT INTO journal_entries VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
|
||||
("BTC", "15m", "5", "1.2", "2.0", "突破", "止盈", "2小时", "", None, ""),
|
||||
)
|
||||
row = conn.execute("SELECT * FROM journal_entries").fetchone()
|
||||
conn.close()
|
||||
text = journal_row_lines_for_ai(1, row)
|
||||
self.assertIn("BTC 15m", text)
|
||||
self.assertIn("实际RR:1.2", text)
|
||||
self.assertIn("开仓逻辑:突破", text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""AI 复盘 journal 文本格式化(四所共用)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from lib.ai.ai_review_lib import journal_row_lines_for_ai # noqa: E402
|
||||
|
||||
|
||||
class TestAiReviewLib(unittest.TestCase):
|
||||
def test_journal_row_includes_expect_and_actual_rr(self):
|
||||
text = journal_row_lines_for_ai(
|
||||
1,
|
||||
{
|
||||
"coin": "HYPE",
|
||||
"tf": "5m",
|
||||
"pnl": "10.73",
|
||||
"real_rr": "2.1354",
|
||||
"expect_rr": "-",
|
||||
"entry_reason": "趋势回调",
|
||||
"exit_reason": "移动止盈",
|
||||
"hold_duration": "1天 3小时",
|
||||
"mood_issues": "",
|
||||
"post_breakeven_stare": "否",
|
||||
"new_trade_while_occupied": "否",
|
||||
"note": "测试备注",
|
||||
},
|
||||
)
|
||||
self.assertIn("实际RR:2.1354", text)
|
||||
self.assertIn("预期RR:-", text)
|
||||
self.assertIn("开仓逻辑:趋势回调", text)
|
||||
self.assertIn("备注:测试备注", text)
|
||||
self.assertNotIn("开仓类型", text)
|
||||
|
||||
def test_journal_row_accepts_sqlite_row(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""CREATE TABLE journal_entries (
|
||||
coin TEXT, tf TEXT, pnl TEXT, real_rr TEXT, expect_rr TEXT,
|
||||
entry_reason TEXT, exit_reason TEXT, hold_duration TEXT,
|
||||
mood_issues TEXT, mood_score INTEGER, note TEXT
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"""INSERT INTO journal_entries VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
|
||||
("BTC", "15m", "5", "1.2", "2.0", "突破", "止盈", "2小时", "", None, ""),
|
||||
)
|
||||
row = conn.execute("SELECT * FROM journal_entries").fetchone()
|
||||
conn.close()
|
||||
text = journal_row_lines_for_ai(1, row)
|
||||
self.assertIn("BTC 15m", text)
|
||||
self.assertIn("实际RR:1.2", text)
|
||||
self.assertIn("开仓逻辑:突破", text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,60 +1,60 @@
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from hub_symbol_archive_lib import init_db, list_archive_calendar, upsert_trades_cache, upsert_trade_overlay
|
||||
|
||||
|
||||
def _bj_ms(y, m, d, hh, mm):
|
||||
dt = datetime(y, m, d, hh, mm, 0, tzinfo=ZoneInfo("Asia/Shanghai"))
|
||||
return int(dt.timestamp() * 1000)
|
||||
|
||||
|
||||
class ArchiveCalendarTests(unittest.TestCase):
|
||||
def test_calendar_groups_by_trading_day_and_sick(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "arch.db"
|
||||
init_db(db)
|
||||
upsert_trades_cache(
|
||||
"binance",
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"symbol": "BTC/USDT",
|
||||
"direction": "long",
|
||||
"result": "止盈",
|
||||
"pnl_amount": 10.0,
|
||||
"opened_at": "2026-06-18 09:00:00",
|
||||
"closed_at": "2026-06-18 10:00:00",
|
||||
"closed_at_ms": _bj_ms(2026, 6, 18, 10, 0),
|
||||
"exchange_turnover_usdt": 2000.0,
|
||||
"exchange_commission_usdt": 0.8,
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"symbol": "ETH/USDT",
|
||||
"direction": "short",
|
||||
"result": "止损",
|
||||
"pnl_amount": -5.0,
|
||||
"opened_at": "2026-06-18 14:00:00",
|
||||
"closed_at": "2026-06-18 15:00:00",
|
||||
"closed_at_ms": _bj_ms(2026, 6, 18, 15, 0),
|
||||
},
|
||||
],
|
||||
db_path=db,
|
||||
)
|
||||
upsert_trade_overlay("binance", 2, behavior_tag="sick", db_path=db)
|
||||
payload = list_archive_calendar(2026, 6, db_path=db)
|
||||
self.assertEqual(payload["month"], 6)
|
||||
days = payload["days"]
|
||||
self.assertTrue(days)
|
||||
sick_days = [d for d in days.values() if d.get("has_sick")]
|
||||
self.assertTrue(sick_days)
|
||||
self.assertGreaterEqual(payload["month_open_count"], 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from lib.hub.hub_symbol_archive_lib import init_db, list_archive_calendar, upsert_trades_cache, upsert_trade_overlay
|
||||
|
||||
|
||||
def _bj_ms(y, m, d, hh, mm):
|
||||
dt = datetime(y, m, d, hh, mm, 0, tzinfo=ZoneInfo("Asia/Shanghai"))
|
||||
return int(dt.timestamp() * 1000)
|
||||
|
||||
|
||||
class ArchiveCalendarTests(unittest.TestCase):
|
||||
def test_calendar_groups_by_trading_day_and_sick(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "arch.db"
|
||||
init_db(db)
|
||||
upsert_trades_cache(
|
||||
"binance",
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"symbol": "BTC/USDT",
|
||||
"direction": "long",
|
||||
"result": "止盈",
|
||||
"pnl_amount": 10.0,
|
||||
"opened_at": "2026-06-18 09:00:00",
|
||||
"closed_at": "2026-06-18 10:00:00",
|
||||
"closed_at_ms": _bj_ms(2026, 6, 18, 10, 0),
|
||||
"exchange_turnover_usdt": 2000.0,
|
||||
"exchange_commission_usdt": 0.8,
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"symbol": "ETH/USDT",
|
||||
"direction": "short",
|
||||
"result": "止损",
|
||||
"pnl_amount": -5.0,
|
||||
"opened_at": "2026-06-18 14:00:00",
|
||||
"closed_at": "2026-06-18 15:00:00",
|
||||
"closed_at_ms": _bj_ms(2026, 6, 18, 15, 0),
|
||||
},
|
||||
],
|
||||
db_path=db,
|
||||
)
|
||||
upsert_trade_overlay("binance", 2, behavior_tag="sick", db_path=db)
|
||||
payload = list_archive_calendar(2026, 6, db_path=db)
|
||||
self.assertEqual(payload["month"], 6)
|
||||
days = payload["days"]
|
||||
self.assertTrue(days)
|
||||
sick_days = [d for d in days.values() if d.get("has_sick")]
|
||||
self.assertTrue(sick_days)
|
||||
self.assertGreaterEqual(payload["month_open_count"], 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,90 +1,90 @@
|
||||
import unittest
|
||||
|
||||
from daily_open_limit_lib import (
|
||||
build_daily_open_alert_prompt,
|
||||
can_trade_new_open,
|
||||
check_daily_open_hard_limit,
|
||||
count_opens_for_trading_day,
|
||||
daily_open_hard_limit_blocks,
|
||||
format_daily_open_counter_line,
|
||||
hard_limit_block_reason,
|
||||
load_daily_open_limits_from_env,
|
||||
parse_daily_open_hard_limit,
|
||||
should_send_daily_open_alert,
|
||||
)
|
||||
|
||||
|
||||
class _FakeConn:
|
||||
def __init__(self, count: int):
|
||||
self._count = count
|
||||
|
||||
def execute(self, _sql, _params):
|
||||
return self
|
||||
|
||||
def fetchone(self):
|
||||
return (self._count,)
|
||||
|
||||
|
||||
class DailyOpenLimitLibTests(unittest.TestCase):
|
||||
def test_parse_hard_limit_zero_disables(self):
|
||||
self.assertEqual(parse_daily_open_hard_limit("0"), 0)
|
||||
self.assertEqual(parse_daily_open_hard_limit(None, default=0), 0)
|
||||
|
||||
def test_load_from_env(self):
|
||||
alert, hard = load_daily_open_limits_from_env(
|
||||
{"DAILY_OPEN_ALERT_THRESHOLD": "3", "DAILY_OPEN_HARD_LIMIT": "8"}
|
||||
)
|
||||
self.assertEqual(alert, 3)
|
||||
self.assertEqual(hard, 8)
|
||||
|
||||
def test_hard_limit_blocks(self):
|
||||
self.assertFalse(daily_open_hard_limit_blocks(4, 0))
|
||||
self.assertFalse(daily_open_hard_limit_blocks(4, 5))
|
||||
self.assertTrue(daily_open_hard_limit_blocks(5, 5))
|
||||
|
||||
def test_check_daily_open_hard_limit(self):
|
||||
conn = _FakeConn(5)
|
||||
ok, reason, n = check_daily_open_hard_limit(conn, "2026-06-07", 5, 8)
|
||||
self.assertFalse(ok)
|
||||
self.assertEqual(n, 5)
|
||||
self.assertIn("已达上限", reason)
|
||||
self.assertIn("8:00", reason)
|
||||
|
||||
def test_count_opens(self):
|
||||
self.assertEqual(count_opens_for_trading_day(_FakeConn(3), "2026-06-07"), 3)
|
||||
|
||||
def test_can_trade_new_open(self):
|
||||
self.assertTrue(
|
||||
can_trade_new_open(
|
||||
time_allows=True,
|
||||
active_count=0,
|
||||
max_active_positions=1,
|
||||
opens_today=2,
|
||||
hard_limit=5,
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
can_trade_new_open(
|
||||
time_allows=True,
|
||||
active_count=0,
|
||||
max_active_positions=1,
|
||||
opens_today=5,
|
||||
hard_limit=5,
|
||||
)
|
||||
)
|
||||
|
||||
def test_alert_crossing(self):
|
||||
self.assertTrue(should_send_daily_open_alert(4, 5, 5))
|
||||
self.assertFalse(should_send_daily_open_alert(5, 6, 5))
|
||||
|
||||
def test_prompt_includes_hard_limit(self):
|
||||
txt = build_daily_open_alert_prompt("2026-06-07", 5, 5, hard_limit=8)
|
||||
self.assertIn("硬上限 8", txt)
|
||||
|
||||
def test_counter_line(self):
|
||||
line = format_daily_open_counter_line(3, 5, 8)
|
||||
self.assertIn("3 / 硬上限 8", line)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
import unittest
|
||||
|
||||
from lib.trade.daily_open_limit_lib import (
|
||||
build_daily_open_alert_prompt,
|
||||
can_trade_new_open,
|
||||
check_daily_open_hard_limit,
|
||||
count_opens_for_trading_day,
|
||||
daily_open_hard_limit_blocks,
|
||||
format_daily_open_counter_line,
|
||||
hard_limit_block_reason,
|
||||
load_daily_open_limits_from_env,
|
||||
parse_daily_open_hard_limit,
|
||||
should_send_daily_open_alert,
|
||||
)
|
||||
|
||||
|
||||
class _FakeConn:
|
||||
def __init__(self, count: int):
|
||||
self._count = count
|
||||
|
||||
def execute(self, _sql, _params):
|
||||
return self
|
||||
|
||||
def fetchone(self):
|
||||
return (self._count,)
|
||||
|
||||
|
||||
class DailyOpenLimitLibTests(unittest.TestCase):
|
||||
def test_parse_hard_limit_zero_disables(self):
|
||||
self.assertEqual(parse_daily_open_hard_limit("0"), 0)
|
||||
self.assertEqual(parse_daily_open_hard_limit(None, default=0), 0)
|
||||
|
||||
def test_load_from_env(self):
|
||||
alert, hard = load_daily_open_limits_from_env(
|
||||
{"DAILY_OPEN_ALERT_THRESHOLD": "3", "DAILY_OPEN_HARD_LIMIT": "8"}
|
||||
)
|
||||
self.assertEqual(alert, 3)
|
||||
self.assertEqual(hard, 8)
|
||||
|
||||
def test_hard_limit_blocks(self):
|
||||
self.assertFalse(daily_open_hard_limit_blocks(4, 0))
|
||||
self.assertFalse(daily_open_hard_limit_blocks(4, 5))
|
||||
self.assertTrue(daily_open_hard_limit_blocks(5, 5))
|
||||
|
||||
def test_check_daily_open_hard_limit(self):
|
||||
conn = _FakeConn(5)
|
||||
ok, reason, n = check_daily_open_hard_limit(conn, "2026-06-07", 5, 8)
|
||||
self.assertFalse(ok)
|
||||
self.assertEqual(n, 5)
|
||||
self.assertIn("已达上限", reason)
|
||||
self.assertIn("8:00", reason)
|
||||
|
||||
def test_count_opens(self):
|
||||
self.assertEqual(count_opens_for_trading_day(_FakeConn(3), "2026-06-07"), 3)
|
||||
|
||||
def test_can_trade_new_open(self):
|
||||
self.assertTrue(
|
||||
can_trade_new_open(
|
||||
time_allows=True,
|
||||
active_count=0,
|
||||
max_active_positions=1,
|
||||
opens_today=2,
|
||||
hard_limit=5,
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
can_trade_new_open(
|
||||
time_allows=True,
|
||||
active_count=0,
|
||||
max_active_positions=1,
|
||||
opens_today=5,
|
||||
hard_limit=5,
|
||||
)
|
||||
)
|
||||
|
||||
def test_alert_crossing(self):
|
||||
self.assertTrue(should_send_daily_open_alert(4, 5, 5))
|
||||
self.assertFalse(should_send_daily_open_alert(5, 6, 5))
|
||||
|
||||
def test_prompt_includes_hard_limit(self):
|
||||
txt = build_daily_open_alert_prompt("2026-06-07", 5, 5, hard_limit=8)
|
||||
self.assertIn("硬上限 8", txt)
|
||||
|
||||
def test_counter_line(self):
|
||||
line = format_daily_open_counter_line(3, 5, 8)
|
||||
self.assertIn("3 / 硬上限 8", line)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,76 +1,76 @@
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from false_breakout_key_monitor_lib import (
|
||||
FALSE_BREAKOUT_MONITOR_TYPE,
|
||||
calc_false_breakout_plan,
|
||||
false_breakout_gate_preview,
|
||||
is_false_breakout_expired,
|
||||
key_price_from_row,
|
||||
normalize_false_breakout_symbol,
|
||||
storage_bounds_from_key_price,
|
||||
)
|
||||
|
||||
|
||||
class FalseBreakoutKeyMonitorLibTests(unittest.TestCase):
|
||||
def test_normalize_symbol(self):
|
||||
self.assertEqual(normalize_false_breakout_symbol("btc"), "BTC/USDT")
|
||||
self.assertEqual(normalize_false_breakout_symbol("ETH/USDT"), "ETH/USDT")
|
||||
self.assertIsNone(normalize_false_breakout_symbol("SOL"))
|
||||
|
||||
def test_short_plan(self):
|
||||
plan = calc_false_breakout_plan("short", 100000)
|
||||
self.assertIsNotNone(plan)
|
||||
entry, sl, tp = plan
|
||||
self.assertAlmostEqual(entry, 100100.0)
|
||||
self.assertAlmostEqual(sl, 100600.5)
|
||||
self.assertAlmostEqual(tp, 99349.25)
|
||||
|
||||
def test_long_plan(self):
|
||||
plan = calc_false_breakout_plan("long", 100000)
|
||||
self.assertIsNotNone(plan)
|
||||
entry, sl, tp = plan
|
||||
self.assertAlmostEqual(entry, 99900.0)
|
||||
self.assertAlmostEqual(sl, 99400.5)
|
||||
self.assertAlmostEqual(tp, 100649.25)
|
||||
|
||||
def test_storage_bounds(self):
|
||||
up, low = storage_bounds_from_key_price("short", 100000)
|
||||
self.assertGreater(up, low)
|
||||
self.assertAlmostEqual(up, 100000.0)
|
||||
self.assertAlmostEqual(low, 99990.0)
|
||||
up, low = storage_bounds_from_key_price("long", 100000)
|
||||
self.assertGreater(up, low)
|
||||
self.assertAlmostEqual(low, 100000.0)
|
||||
self.assertAlmostEqual(up, 100010.0)
|
||||
|
||||
def test_key_price_from_row(self):
|
||||
self.assertEqual(key_price_from_row("short", 100100, 100000), 100100)
|
||||
self.assertEqual(key_price_from_row("long", 100100, 100000), 100000)
|
||||
|
||||
def test_expiry(self):
|
||||
now = datetime(2026, 6, 9, 12, 0, 0)
|
||||
created = "2026-06-08 12:00:00"
|
||||
self.assertTrue(is_false_breakout_expired(created, now))
|
||||
self.assertFalse(is_false_breakout_expired(created, now - timedelta(hours=1)))
|
||||
|
||||
def test_monitor_type_constant(self):
|
||||
self.assertEqual(FALSE_BREAKOUT_MONITOR_TYPE, "假突破")
|
||||
|
||||
def test_gate_preview_not_box_gate(self):
|
||||
now = datetime(2026, 6, 7, 12, 0, 0)
|
||||
prev = false_breakout_gate_preview(
|
||||
entry_display="1635.0",
|
||||
limit_order_id="oid-1",
|
||||
created_at="2026-06-07 10:00:00",
|
||||
now=now,
|
||||
)
|
||||
self.assertIn("假突破", prev["summary"])
|
||||
self.assertIn("等待成交", prev["summary"])
|
||||
self.assertNotIn("量:", prev["summary"])
|
||||
self.assertIn("限价单:oid-1", prev["metrics"])
|
||||
self.assertTrue(prev["gate_ok"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from lib.key_monitor.false_breakout_key_monitor_lib import (
|
||||
FALSE_BREAKOUT_MONITOR_TYPE,
|
||||
calc_false_breakout_plan,
|
||||
false_breakout_gate_preview,
|
||||
is_false_breakout_expired,
|
||||
key_price_from_row,
|
||||
normalize_false_breakout_symbol,
|
||||
storage_bounds_from_key_price,
|
||||
)
|
||||
|
||||
|
||||
class FalseBreakoutKeyMonitorLibTests(unittest.TestCase):
|
||||
def test_normalize_symbol(self):
|
||||
self.assertEqual(normalize_false_breakout_symbol("btc"), "BTC/USDT")
|
||||
self.assertEqual(normalize_false_breakout_symbol("ETH/USDT"), "ETH/USDT")
|
||||
self.assertIsNone(normalize_false_breakout_symbol("SOL"))
|
||||
|
||||
def test_short_plan(self):
|
||||
plan = calc_false_breakout_plan("short", 100000)
|
||||
self.assertIsNotNone(plan)
|
||||
entry, sl, tp = plan
|
||||
self.assertAlmostEqual(entry, 100100.0)
|
||||
self.assertAlmostEqual(sl, 100600.5)
|
||||
self.assertAlmostEqual(tp, 99349.25)
|
||||
|
||||
def test_long_plan(self):
|
||||
plan = calc_false_breakout_plan("long", 100000)
|
||||
self.assertIsNotNone(plan)
|
||||
entry, sl, tp = plan
|
||||
self.assertAlmostEqual(entry, 99900.0)
|
||||
self.assertAlmostEqual(sl, 99400.5)
|
||||
self.assertAlmostEqual(tp, 100649.25)
|
||||
|
||||
def test_storage_bounds(self):
|
||||
up, low = storage_bounds_from_key_price("short", 100000)
|
||||
self.assertGreater(up, low)
|
||||
self.assertAlmostEqual(up, 100000.0)
|
||||
self.assertAlmostEqual(low, 99990.0)
|
||||
up, low = storage_bounds_from_key_price("long", 100000)
|
||||
self.assertGreater(up, low)
|
||||
self.assertAlmostEqual(low, 100000.0)
|
||||
self.assertAlmostEqual(up, 100010.0)
|
||||
|
||||
def test_key_price_from_row(self):
|
||||
self.assertEqual(key_price_from_row("short", 100100, 100000), 100100)
|
||||
self.assertEqual(key_price_from_row("long", 100100, 100000), 100000)
|
||||
|
||||
def test_expiry(self):
|
||||
now = datetime(2026, 6, 9, 12, 0, 0)
|
||||
created = "2026-06-08 12:00:00"
|
||||
self.assertTrue(is_false_breakout_expired(created, now))
|
||||
self.assertFalse(is_false_breakout_expired(created, now - timedelta(hours=1)))
|
||||
|
||||
def test_monitor_type_constant(self):
|
||||
self.assertEqual(FALSE_BREAKOUT_MONITOR_TYPE, "假突破")
|
||||
|
||||
def test_gate_preview_not_box_gate(self):
|
||||
now = datetime(2026, 6, 7, 12, 0, 0)
|
||||
prev = false_breakout_gate_preview(
|
||||
entry_display="1635.0",
|
||||
limit_order_id="oid-1",
|
||||
created_at="2026-06-07 10:00:00",
|
||||
now=now,
|
||||
)
|
||||
self.assertIn("假突破", prev["summary"])
|
||||
self.assertIn("等待成交", prev["summary"])
|
||||
self.assertNotIn("量:", prev["summary"])
|
||||
self.assertIn("限价单:oid-1", prev["metrics"])
|
||||
self.assertTrue(prev["gate_ok"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
from gate_position_history_lib import pick_gate_position_close, unified_symbol_for_match
|
||||
|
||||
|
||||
def test_unified_symbol_strips_settle_suffix():
|
||||
assert unified_symbol_for_match("BTC/USDT:USDT") == "BTC/USDT"
|
||||
|
||||
|
||||
def test_pick_gate_position_close_matches_symbol_side_and_time():
|
||||
hist = [
|
||||
{
|
||||
"symbol_u": "SOL/USDT",
|
||||
"side": "short",
|
||||
"close_ms": 1_700_000_000_000,
|
||||
"open_ms": 1_699_999_000_000,
|
||||
"pnl": -1.25,
|
||||
"sync_key": "SOL_USDT|1|short",
|
||||
}
|
||||
]
|
||||
hit = pick_gate_position_close(
|
||||
hist,
|
||||
"SOL/USDT:USDT",
|
||||
"short",
|
||||
opened_at_ms=1_699_999_500_000,
|
||||
)
|
||||
assert hit is not None
|
||||
assert hit["pnl"] == -1.25
|
||||
from lib.exchange.gate_position_history_lib import pick_gate_position_close, unified_symbol_for_match
|
||||
|
||||
|
||||
def test_unified_symbol_strips_settle_suffix():
|
||||
assert unified_symbol_for_match("BTC/USDT:USDT") == "BTC/USDT"
|
||||
|
||||
|
||||
def test_pick_gate_position_close_matches_symbol_side_and_time():
|
||||
hist = [
|
||||
{
|
||||
"symbol_u": "SOL/USDT",
|
||||
"side": "short",
|
||||
"close_ms": 1_700_000_000_000,
|
||||
"open_ms": 1_699_999_000_000,
|
||||
"pnl": -1.25,
|
||||
"sync_key": "SOL_USDT|1|short",
|
||||
}
|
||||
]
|
||||
hit = pick_gate_position_close(
|
||||
hist,
|
||||
"SOL/USDT:USDT",
|
||||
"short",
|
||||
opened_at_ms=1_699_999_500_000,
|
||||
)
|
||||
assert hit is not None
|
||||
assert hit["pnl"] == -1.25
|
||||
|
||||
@@ -1,44 +1,44 @@
|
||||
"""gate_transfer_lib 单元测试。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import unittest
|
||||
|
||||
from gate_transfer_lib import count_auto_transfer_blockers
|
||||
|
||||
|
||||
class GateTransferLibTest(unittest.TestCase):
|
||||
def test_counts_order_monitors_first(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.execute("CREATE TABLE order_monitors (status TEXT)")
|
||||
conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)")
|
||||
conn.execute("INSERT INTO order_monitors VALUES ('active')")
|
||||
conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 1)")
|
||||
conn.commit()
|
||||
n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 1)
|
||||
self.assertEqual(n, 1)
|
||||
conn.close()
|
||||
|
||||
def test_counts_trend_plan_when_no_order_monitors(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.execute("CREATE TABLE order_monitors (status TEXT)")
|
||||
conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)")
|
||||
conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 1)")
|
||||
conn.commit()
|
||||
n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 0)
|
||||
self.assertEqual(n, 1)
|
||||
conn.close()
|
||||
|
||||
def test_ignores_trend_plan_without_first_order(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.execute("CREATE TABLE order_monitors (status TEXT)")
|
||||
conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)")
|
||||
conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 0)")
|
||||
conn.commit()
|
||||
n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 0)
|
||||
self.assertEqual(n, 0)
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""gate_transfer_lib 单元测试。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import unittest
|
||||
|
||||
from lib.exchange.gate_transfer_lib import count_auto_transfer_blockers
|
||||
|
||||
|
||||
class GateTransferLibTest(unittest.TestCase):
|
||||
def test_counts_order_monitors_first(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.execute("CREATE TABLE order_monitors (status TEXT)")
|
||||
conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)")
|
||||
conn.execute("INSERT INTO order_monitors VALUES ('active')")
|
||||
conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 1)")
|
||||
conn.commit()
|
||||
n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 1)
|
||||
self.assertEqual(n, 1)
|
||||
conn.close()
|
||||
|
||||
def test_counts_trend_plan_when_no_order_monitors(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.execute("CREATE TABLE order_monitors (status TEXT)")
|
||||
conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)")
|
||||
conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 1)")
|
||||
conn.commit()
|
||||
n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 0)
|
||||
self.assertEqual(n, 1)
|
||||
conn.close()
|
||||
|
||||
def test_ignores_trend_plan_without_first_order(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.execute("CREATE TABLE order_monitors (status TEXT)")
|
||||
conn.execute("CREATE TABLE trend_pullback_plans (status TEXT, first_order_done INTEGER)")
|
||||
conn.execute("INSERT INTO trend_pullback_plans VALUES ('active', 0)")
|
||||
conn.commit()
|
||||
n = count_auto_transfer_blockers(conn, count_order_monitors=lambda c: 0)
|
||||
self.assertEqual(n, 0)
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,94 +1,94 @@
|
||||
"""子代理持仓:四所标记价字段统一解析。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT / "manual_trading_hub"))
|
||||
|
||||
from agent import _position_mark_price, _ticker_mark_price # noqa: E402
|
||||
|
||||
sys.path.insert(0, str(ROOT))
|
||||
from hub_position_metrics import ( # noqa: E402
|
||||
enrich_ccxt_position_metrics_out,
|
||||
estimate_linear_swap_upnl_usdt,
|
||||
parse_position_unrealized_pnl,
|
||||
resolve_position_display_upnl,
|
||||
)
|
||||
|
||||
|
||||
class TestHubAgentMarkPrice(unittest.TestCase):
|
||||
def test_binance_mark_price(self):
|
||||
px = _position_mark_price({"markPrice": 65880.1, "info": {}})
|
||||
self.assertAlmostEqual(px, 65880.1)
|
||||
|
||||
def test_okx_mark_px(self):
|
||||
px = _position_mark_price({"info": {"markPx": "72.85"}})
|
||||
self.assertAlmostEqual(px, 72.85)
|
||||
|
||||
def test_gate_info_mark(self):
|
||||
px = _position_mark_price({"info": {"mark_price": "0.2241"}})
|
||||
self.assertAlmostEqual(px, 0.2241)
|
||||
|
||||
def test_missing_returns_none(self):
|
||||
self.assertIsNone(_position_mark_price({"info": {}}))
|
||||
|
||||
def test_infer_from_notional_and_contracts(self):
|
||||
p = {"notional": 1000, "contracts": 10, "info": {}}
|
||||
px = _position_mark_price(p)
|
||||
self.assertAlmostEqual(px, 100.0)
|
||||
|
||||
def test_ticker_fallback(self):
|
||||
class _Ex:
|
||||
def fetch_ticker(self, sym):
|
||||
return {"mark": 99.5, "info": {}}
|
||||
|
||||
self.assertAlmostEqual(_ticker_mark_price(_Ex(), "BTC/USDT:USDT"), 99.5)
|
||||
|
||||
def test_gate_unrealised_pnl_in_info(self):
|
||||
pnl = parse_position_unrealized_pnl(
|
||||
{"info": {"unrealised_pnl": "6.81"}, "unrealizedPnl": None}
|
||||
)
|
||||
self.assertAlmostEqual(pnl, 6.81)
|
||||
|
||||
def test_okx_upl_signed(self):
|
||||
pnl = parse_position_unrealized_pnl(
|
||||
{"info": {"upl": "-2.15"}, "unrealizedPnl": None}
|
||||
)
|
||||
self.assertAlmostEqual(pnl, -2.15)
|
||||
|
||||
def test_enrich_aligns_short_gate_metrics(self):
|
||||
pos = {
|
||||
"side": "short",
|
||||
"contracts": 11,
|
||||
"entryPrice": 73.187,
|
||||
"markPrice": 66.038,
|
||||
"info": {"unrealised_pnl": "7.86"},
|
||||
}
|
||||
out = {"unrealized_pnl": 7.86, "mark_price": 66.038}
|
||||
enrich_ccxt_position_metrics_out(pos, out, contract_size=1.0, funds_decimals=2)
|
||||
self.assertGreater(out["unrealized_pnl"], 70.0)
|
||||
|
||||
def test_estimate_short_hype_contract_size(self):
|
||||
upnl = estimate_linear_swap_upnl_usdt(
|
||||
"short", 73.187, 66.038, 11, 0.1
|
||||
)
|
||||
self.assertAlmostEqual(upnl, 7.86, places=1)
|
||||
|
||||
def test_resolve_prefers_computed_when_exchange_off(self):
|
||||
shown = resolve_position_display_upnl(
|
||||
"short", 73.187, 66.038, 11, 1.0, 7.86
|
||||
)
|
||||
self.assertAlmostEqual(shown, 78.64, places=1)
|
||||
|
||||
def test_resolve_keeps_exchange_when_aligned(self):
|
||||
shown = resolve_position_display_upnl(
|
||||
"short", 73.187, 66.038, 11, 0.1, 7.86
|
||||
)
|
||||
self.assertAlmostEqual(shown, 7.86, places=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""子代理持仓:四所标记价字段统一解析。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT / "manual_trading_hub"))
|
||||
|
||||
from agent import _position_mark_price, _ticker_mark_price # noqa: E402
|
||||
|
||||
sys.path.insert(0, str(ROOT))
|
||||
from lib.hub.hub_position_metrics import ( # noqa: E402
|
||||
enrich_ccxt_position_metrics_out,
|
||||
estimate_linear_swap_upnl_usdt,
|
||||
parse_position_unrealized_pnl,
|
||||
resolve_position_display_upnl,
|
||||
)
|
||||
|
||||
|
||||
class TestHubAgentMarkPrice(unittest.TestCase):
|
||||
def test_binance_mark_price(self):
|
||||
px = _position_mark_price({"markPrice": 65880.1, "info": {}})
|
||||
self.assertAlmostEqual(px, 65880.1)
|
||||
|
||||
def test_okx_mark_px(self):
|
||||
px = _position_mark_price({"info": {"markPx": "72.85"}})
|
||||
self.assertAlmostEqual(px, 72.85)
|
||||
|
||||
def test_gate_info_mark(self):
|
||||
px = _position_mark_price({"info": {"mark_price": "0.2241"}})
|
||||
self.assertAlmostEqual(px, 0.2241)
|
||||
|
||||
def test_missing_returns_none(self):
|
||||
self.assertIsNone(_position_mark_price({"info": {}}))
|
||||
|
||||
def test_infer_from_notional_and_contracts(self):
|
||||
p = {"notional": 1000, "contracts": 10, "info": {}}
|
||||
px = _position_mark_price(p)
|
||||
self.assertAlmostEqual(px, 100.0)
|
||||
|
||||
def test_ticker_fallback(self):
|
||||
class _Ex:
|
||||
def fetch_ticker(self, sym):
|
||||
return {"mark": 99.5, "info": {}}
|
||||
|
||||
self.assertAlmostEqual(_ticker_mark_price(_Ex(), "BTC/USDT:USDT"), 99.5)
|
||||
|
||||
def test_gate_unrealised_pnl_in_info(self):
|
||||
pnl = parse_position_unrealized_pnl(
|
||||
{"info": {"unrealised_pnl": "6.81"}, "unrealizedPnl": None}
|
||||
)
|
||||
self.assertAlmostEqual(pnl, 6.81)
|
||||
|
||||
def test_okx_upl_signed(self):
|
||||
pnl = parse_position_unrealized_pnl(
|
||||
{"info": {"upl": "-2.15"}, "unrealizedPnl": None}
|
||||
)
|
||||
self.assertAlmostEqual(pnl, -2.15)
|
||||
|
||||
def test_enrich_aligns_short_gate_metrics(self):
|
||||
pos = {
|
||||
"side": "short",
|
||||
"contracts": 11,
|
||||
"entryPrice": 73.187,
|
||||
"markPrice": 66.038,
|
||||
"info": {"unrealised_pnl": "7.86"},
|
||||
}
|
||||
out = {"unrealized_pnl": 7.86, "mark_price": 66.038}
|
||||
enrich_ccxt_position_metrics_out(pos, out, contract_size=1.0, funds_decimals=2)
|
||||
self.assertGreater(out["unrealized_pnl"], 70.0)
|
||||
|
||||
def test_estimate_short_hype_contract_size(self):
|
||||
upnl = estimate_linear_swap_upnl_usdt(
|
||||
"short", 73.187, 66.038, 11, 0.1
|
||||
)
|
||||
self.assertAlmostEqual(upnl, 7.86, places=1)
|
||||
|
||||
def test_resolve_prefers_computed_when_exchange_off(self):
|
||||
shown = resolve_position_display_upnl(
|
||||
"short", 73.187, 66.038, 11, 1.0, 7.86
|
||||
)
|
||||
self.assertAlmostEqual(shown, 78.64, places=1)
|
||||
|
||||
def test_resolve_keeps_exchange_when_aligned(self):
|
||||
shown = resolve_position_display_upnl(
|
||||
"short", 73.187, 66.038, 11, 0.1, 7.86
|
||||
)
|
||||
self.assertAlmostEqual(shown, 7.86, places=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
+162
-162
@@ -1,162 +1,162 @@
|
||||
"""hub_calculator_lib 测算逻辑。"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from hub_calculator_lib import (
|
||||
calc_initial_roll_qty,
|
||||
calc_roll_calculator,
|
||||
calc_trend_calculator,
|
||||
solve_add_amount_for_total_risk,
|
||||
)
|
||||
|
||||
MOCK_MARKET = {
|
||||
"exchange_id": "0",
|
||||
"exchange_key": "binance",
|
||||
"exchange_name": "币安 · crypto_monitor_binance",
|
||||
"exchange_label": "币安 · crypto_monitor_binance",
|
||||
"base": "ETH",
|
||||
"exchange_symbol": "ETH/USDT:USDT",
|
||||
"display_symbol": "ETH/USDT",
|
||||
"contract_size": 1.0,
|
||||
"price_tick": 0.01,
|
||||
"price_decimals": 2,
|
||||
"amount_decimals": 3,
|
||||
"min_amount": 0.001,
|
||||
}
|
||||
|
||||
|
||||
def _mock_resolve(_exchange="binance", _base="ETH"):
|
||||
return MOCK_MARKET, lambda amount: round(float(amount), 3), None
|
||||
|
||||
|
||||
class HubCalculatorLibTests(unittest.TestCase):
|
||||
@patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve())
|
||||
def test_trend_calculator_long_basic(self, _mock):
|
||||
data, err = calc_trend_calculator(
|
||||
direction="long",
|
||||
capital_usdt=1000,
|
||||
risk_percent=5,
|
||||
leverage=5,
|
||||
entry_price=100,
|
||||
stop_loss=95,
|
||||
add_upper=110,
|
||||
take_profit=120,
|
||||
dca_legs=3,
|
||||
exchange_id="0",
|
||||
base="ETH",
|
||||
)
|
||||
self.assertIsNone(err)
|
||||
self.assertIsNotNone(data)
|
||||
assert data is not None
|
||||
self.assertEqual(data["risk_budget_u"], 50.0)
|
||||
self.assertGreaterEqual(len(data["rows"]), 2)
|
||||
self.assertEqual(data["rows"][0]["label"], "首仓")
|
||||
self.assertEqual(data["market"]["display_symbol"], "ETH/USDT")
|
||||
|
||||
@patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve())
|
||||
def test_trend_calculator_short_rejects_bad_bounds(self, _mock):
|
||||
data, err = calc_trend_calculator(
|
||||
direction="short",
|
||||
capital_usdt=1000,
|
||||
risk_percent=5,
|
||||
leverage=5,
|
||||
entry_price=100,
|
||||
stop_loss=90,
|
||||
add_upper=110,
|
||||
take_profit=80,
|
||||
dca_legs=3,
|
||||
)
|
||||
self.assertIsNone(data)
|
||||
self.assertIsNotNone(err)
|
||||
|
||||
@patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve())
|
||||
def test_roll_calculator_first_leg_auto(self, _mock):
|
||||
data, err = calc_roll_calculator(
|
||||
direction="long",
|
||||
capital_usdt=1000,
|
||||
risk_percent=5,
|
||||
entry_price=100,
|
||||
stop_loss=95,
|
||||
take_profit=120,
|
||||
add_legs=[],
|
||||
legs_done=0,
|
||||
)
|
||||
self.assertIsNone(err)
|
||||
self.assertIsNotNone(data)
|
||||
assert data is not None
|
||||
self.assertEqual(data["first_contracts"], 10.0)
|
||||
self.assertEqual(len(data["rows"]), 1)
|
||||
self.assertEqual(data["rows"][0]["loss_at_sl_u"], 50.0)
|
||||
self.assertEqual(data["rows"][0]["profit_at_tp_u"], 200.0)
|
||||
|
||||
@patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve())
|
||||
def test_roll_calculator_chain_two_legs(self, _mock):
|
||||
data, err = calc_roll_calculator(
|
||||
direction="long",
|
||||
capital_usdt=1000,
|
||||
risk_percent=5,
|
||||
entry_price=100,
|
||||
stop_loss=95,
|
||||
take_profit=120,
|
||||
add_legs=[
|
||||
{"add_price": 105, "new_stop_loss": 98},
|
||||
{"add_price": 108, "new_stop_loss": 101},
|
||||
],
|
||||
legs_done=0,
|
||||
)
|
||||
self.assertIsNone(err)
|
||||
self.assertIsNotNone(data)
|
||||
assert data is not None
|
||||
self.assertEqual(len(data["rows"]), 3)
|
||||
self.assertEqual(data["rows"][1]["label"], "滚仓1")
|
||||
self.assertGreater(float(data["final_contracts"]), float(data["first_contracts"]))
|
||||
|
||||
@patch("hub_calculator_lib._resolve_market", return_value=_mock_resolve())
|
||||
def test_roll_calculator_rejects_too_many_legs(self, _mock):
|
||||
data, err = calc_roll_calculator(
|
||||
direction="long",
|
||||
capital_usdt=1000,
|
||||
risk_percent=5,
|
||||
entry_price=100,
|
||||
stop_loss=95,
|
||||
take_profit=120,
|
||||
add_legs=[
|
||||
{"add_price": 105, "new_stop_loss": 98},
|
||||
{"add_price": 108, "new_stop_loss": 101},
|
||||
{"add_price": 110, "new_stop_loss": 103},
|
||||
{"add_price": 112, "new_stop_loss": 105},
|
||||
],
|
||||
legs_done=0,
|
||||
)
|
||||
self.assertIsNone(data)
|
||||
self.assertIsNotNone(err)
|
||||
|
||||
def test_initial_roll_qty(self):
|
||||
qty, err = calc_initial_roll_qty("long", 100, 95, 50, 1.0)
|
||||
self.assertIsNone(err)
|
||||
self.assertEqual(qty, 10.0)
|
||||
|
||||
def test_initial_roll_qty_with_contract_size(self):
|
||||
qty, err = calc_initial_roll_qty("long", 100, 95, 50, 0.1)
|
||||
self.assertIsNone(err)
|
||||
self.assertEqual(qty, 100.0)
|
||||
|
||||
def test_solve_add_with_contract_size(self):
|
||||
q2, err = solve_add_amount_for_total_risk(
|
||||
"long",
|
||||
qty_existing=10.0,
|
||||
entry_existing=100.0,
|
||||
add_price=105.0,
|
||||
new_stop=98.0,
|
||||
risk_budget_usdt=50.0,
|
||||
contract_size=1.0,
|
||||
)
|
||||
self.assertIsNone(err)
|
||||
self.assertIsNotNone(q2)
|
||||
assert q2 is not None
|
||||
self.assertGreater(q2, 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""hub_calculator_lib 测算逻辑。"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from lib.hub.hub_calculator_lib import (
|
||||
calc_initial_roll_qty,
|
||||
calc_roll_calculator,
|
||||
calc_trend_calculator,
|
||||
solve_add_amount_for_total_risk,
|
||||
)
|
||||
|
||||
MOCK_MARKET = {
|
||||
"exchange_id": "0",
|
||||
"exchange_key": "binance",
|
||||
"exchange_name": "币安 · crypto_monitor_binance",
|
||||
"exchange_label": "币安 · crypto_monitor_binance",
|
||||
"base": "ETH",
|
||||
"exchange_symbol": "ETH/USDT:USDT",
|
||||
"display_symbol": "ETH/USDT",
|
||||
"contract_size": 1.0,
|
||||
"price_tick": 0.01,
|
||||
"price_decimals": 2,
|
||||
"amount_decimals": 3,
|
||||
"min_amount": 0.001,
|
||||
}
|
||||
|
||||
|
||||
def _mock_resolve(_exchange="binance", _base="ETH"):
|
||||
return MOCK_MARKET, lambda amount: round(float(amount), 3), None
|
||||
|
||||
|
||||
class HubCalculatorLibTests(unittest.TestCase):
|
||||
@patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve())
|
||||
def test_trend_calculator_long_basic(self, _mock):
|
||||
data, err = calc_trend_calculator(
|
||||
direction="long",
|
||||
capital_usdt=1000,
|
||||
risk_percent=5,
|
||||
leverage=5,
|
||||
entry_price=100,
|
||||
stop_loss=95,
|
||||
add_upper=110,
|
||||
take_profit=120,
|
||||
dca_legs=3,
|
||||
exchange_id="0",
|
||||
base="ETH",
|
||||
)
|
||||
self.assertIsNone(err)
|
||||
self.assertIsNotNone(data)
|
||||
assert data is not None
|
||||
self.assertEqual(data["risk_budget_u"], 50.0)
|
||||
self.assertGreaterEqual(len(data["rows"]), 2)
|
||||
self.assertEqual(data["rows"][0]["label"], "首仓")
|
||||
self.assertEqual(data["market"]["display_symbol"], "ETH/USDT")
|
||||
|
||||
@patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve())
|
||||
def test_trend_calculator_short_rejects_bad_bounds(self, _mock):
|
||||
data, err = calc_trend_calculator(
|
||||
direction="short",
|
||||
capital_usdt=1000,
|
||||
risk_percent=5,
|
||||
leverage=5,
|
||||
entry_price=100,
|
||||
stop_loss=90,
|
||||
add_upper=110,
|
||||
take_profit=80,
|
||||
dca_legs=3,
|
||||
)
|
||||
self.assertIsNone(data)
|
||||
self.assertIsNotNone(err)
|
||||
|
||||
@patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve())
|
||||
def test_roll_calculator_first_leg_auto(self, _mock):
|
||||
data, err = calc_roll_calculator(
|
||||
direction="long",
|
||||
capital_usdt=1000,
|
||||
risk_percent=5,
|
||||
entry_price=100,
|
||||
stop_loss=95,
|
||||
take_profit=120,
|
||||
add_legs=[],
|
||||
legs_done=0,
|
||||
)
|
||||
self.assertIsNone(err)
|
||||
self.assertIsNotNone(data)
|
||||
assert data is not None
|
||||
self.assertEqual(data["first_contracts"], 10.0)
|
||||
self.assertEqual(len(data["rows"]), 1)
|
||||
self.assertEqual(data["rows"][0]["loss_at_sl_u"], 50.0)
|
||||
self.assertEqual(data["rows"][0]["profit_at_tp_u"], 200.0)
|
||||
|
||||
@patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve())
|
||||
def test_roll_calculator_chain_two_legs(self, _mock):
|
||||
data, err = calc_roll_calculator(
|
||||
direction="long",
|
||||
capital_usdt=1000,
|
||||
risk_percent=5,
|
||||
entry_price=100,
|
||||
stop_loss=95,
|
||||
take_profit=120,
|
||||
add_legs=[
|
||||
{"add_price": 105, "new_stop_loss": 98},
|
||||
{"add_price": 108, "new_stop_loss": 101},
|
||||
],
|
||||
legs_done=0,
|
||||
)
|
||||
self.assertIsNone(err)
|
||||
self.assertIsNotNone(data)
|
||||
assert data is not None
|
||||
self.assertEqual(len(data["rows"]), 3)
|
||||
self.assertEqual(data["rows"][1]["label"], "滚仓1")
|
||||
self.assertGreater(float(data["final_contracts"]), float(data["first_contracts"]))
|
||||
|
||||
@patch("lib.hub.hub_calculator_lib._resolve_market", return_value=_mock_resolve())
|
||||
def test_roll_calculator_rejects_too_many_legs(self, _mock):
|
||||
data, err = calc_roll_calculator(
|
||||
direction="long",
|
||||
capital_usdt=1000,
|
||||
risk_percent=5,
|
||||
entry_price=100,
|
||||
stop_loss=95,
|
||||
take_profit=120,
|
||||
add_legs=[
|
||||
{"add_price": 105, "new_stop_loss": 98},
|
||||
{"add_price": 108, "new_stop_loss": 101},
|
||||
{"add_price": 110, "new_stop_loss": 103},
|
||||
{"add_price": 112, "new_stop_loss": 105},
|
||||
],
|
||||
legs_done=0,
|
||||
)
|
||||
self.assertIsNone(data)
|
||||
self.assertIsNotNone(err)
|
||||
|
||||
def test_initial_roll_qty(self):
|
||||
qty, err = calc_initial_roll_qty("long", 100, 95, 50, 1.0)
|
||||
self.assertIsNone(err)
|
||||
self.assertEqual(qty, 10.0)
|
||||
|
||||
def test_initial_roll_qty_with_contract_size(self):
|
||||
qty, err = calc_initial_roll_qty("long", 100, 95, 50, 0.1)
|
||||
self.assertIsNone(err)
|
||||
self.assertEqual(qty, 100.0)
|
||||
|
||||
def test_solve_add_with_contract_size(self):
|
||||
q2, err = solve_add_amount_for_total_risk(
|
||||
"long",
|
||||
qty_existing=10.0,
|
||||
entry_existing=100.0,
|
||||
add_price=105.0,
|
||||
new_stop=98.0,
|
||||
risk_budget_usdt=50.0,
|
||||
contract_size=1.0,
|
||||
)
|
||||
self.assertIsNone(err)
|
||||
self.assertIsNotNone(q2)
|
||||
assert q2 is not None
|
||||
self.assertGreater(q2, 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,113 +1,113 @@
|
||||
"""hub_calculator_market_lib 合约解析。"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from hub_calculator_market_lib import (
|
||||
amount_decimals_from_exchange,
|
||||
find_exchange,
|
||||
get_calculator_market,
|
||||
list_calculator_exchanges,
|
||||
make_amount_precise_fn_from_market,
|
||||
normalize_base_symbol,
|
||||
resolve_usdt_perp_symbol,
|
||||
)
|
||||
|
||||
|
||||
class FakeExchange:
|
||||
def __init__(self, markets: dict):
|
||||
self.markets = markets
|
||||
|
||||
def market(self, symbol: str):
|
||||
return self.markets[symbol]
|
||||
|
||||
def amount_to_precision(self, symbol: str, amount: float) -> str:
|
||||
return f"{float(amount):.3f}"
|
||||
|
||||
|
||||
class HubCalculatorMarketLibTests(unittest.TestCase):
|
||||
def test_normalize_base_symbol(self):
|
||||
self.assertEqual(normalize_base_symbol("eth"), "ETH")
|
||||
self.assertEqual(normalize_base_symbol("ETH/USDT:USDT"), "ETH")
|
||||
self.assertEqual(normalize_base_symbol("ETHUSDT"), "ETH")
|
||||
|
||||
def test_resolve_usdt_perp_symbol(self):
|
||||
ex = FakeExchange(
|
||||
{
|
||||
"ETH/USDT:USDT": {
|
||||
"base": "ETH",
|
||||
"quote": "USDT",
|
||||
"swap": True,
|
||||
"active": True,
|
||||
"contractSize": 1.0,
|
||||
"limits": {"amount": {"min": 0.001}},
|
||||
"precision": {"price": 2, "amount": 3},
|
||||
}
|
||||
}
|
||||
)
|
||||
sym, err = resolve_usdt_perp_symbol(ex, "ETH")
|
||||
self.assertIsNone(err)
|
||||
self.assertEqual(sym, "ETH/USDT:USDT")
|
||||
|
||||
def test_amount_decimals_from_exchange(self):
|
||||
ex = FakeExchange({})
|
||||
self.assertEqual(amount_decimals_from_exchange(ex, "ETH/USDT:USDT"), 3)
|
||||
|
||||
def test_make_amount_precise_fn_from_market(self):
|
||||
fn = make_amount_precise_fn_from_market({"amount_decimals": 3, "min_amount": 0.001})
|
||||
self.assertEqual(fn(1.23456), 1.234)
|
||||
self.assertIsNone(fn(0.0001))
|
||||
|
||||
@patch.dict("os.environ", {"HUB_BRIDGE_TOKEN": "test-token"}, clear=False)
|
||||
def test_hub_headers_use_x_hub_token(self):
|
||||
from hub_calculator_market_lib import _hub_headers
|
||||
|
||||
self.assertEqual(_hub_headers(), {"X-Hub-Token": "test-token"})
|
||||
|
||||
@patch("hub_calculator_market_lib.fetch_instance_market_sync")
|
||||
def test_get_calculator_market_from_instance(self, fetch_mock):
|
||||
fetch_mock.return_value = {
|
||||
"ok": True,
|
||||
"base": "ETH",
|
||||
"exchange_symbol": "ETH/USDT:USDT",
|
||||
"display_symbol": "ETH/USDT",
|
||||
"contract_size": 0.01,
|
||||
"price_tick": 0.01,
|
||||
"price_decimals": 2,
|
||||
"amount_decimals": 2,
|
||||
"min_amount": 0.01,
|
||||
}
|
||||
ex = {
|
||||
"id": "0",
|
||||
"key": "binance",
|
||||
"name": "币安 · crypto_monitor_binance",
|
||||
"enabled": True,
|
||||
"flask_url": "http://127.0.0.1:5001",
|
||||
}
|
||||
data, err = get_calculator_market("0", "ETH", ex=ex)
|
||||
self.assertIsNone(err)
|
||||
self.assertIsNotNone(data)
|
||||
assert data is not None
|
||||
self.assertEqual(data["exchange_id"], "0")
|
||||
self.assertEqual(data["exchange_name"], "币安 · crypto_monitor_binance")
|
||||
self.assertEqual(data["contract_size"], 0.01)
|
||||
|
||||
@patch("hub_calculator_market_lib.enabled_exchanges")
|
||||
def test_list_calculator_exchanges(self, enabled_mock):
|
||||
enabled_mock.return_value = [
|
||||
{"id": "0", "key": "binance", "name": "币安", "enabled": True},
|
||||
]
|
||||
rows = list_calculator_exchanges()
|
||||
self.assertEqual(len(rows), 1)
|
||||
self.assertEqual(rows[0]["id"], "0")
|
||||
|
||||
def test_find_exchange_by_id(self):
|
||||
with patch(
|
||||
"hub_calculator_market_lib.load_settings",
|
||||
return_value={"exchanges": [{"id": "2", "key": "gate", "name": "Gate"}]},
|
||||
):
|
||||
self.assertEqual(find_exchange("2")["key"], "gate")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""hub_calculator_market_lib 合约解析。"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from lib.hub.hub_calculator_market_lib import (
|
||||
amount_decimals_from_exchange,
|
||||
find_exchange,
|
||||
get_calculator_market,
|
||||
list_calculator_exchanges,
|
||||
make_amount_precise_fn_from_market,
|
||||
normalize_base_symbol,
|
||||
resolve_usdt_perp_symbol,
|
||||
)
|
||||
|
||||
|
||||
class FakeExchange:
|
||||
def __init__(self, markets: dict):
|
||||
self.markets = markets
|
||||
|
||||
def market(self, symbol: str):
|
||||
return self.markets[symbol]
|
||||
|
||||
def amount_to_precision(self, symbol: str, amount: float) -> str:
|
||||
return f"{float(amount):.3f}"
|
||||
|
||||
|
||||
class HubCalculatorMarketLibTests(unittest.TestCase):
|
||||
def test_normalize_base_symbol(self):
|
||||
self.assertEqual(normalize_base_symbol("eth"), "ETH")
|
||||
self.assertEqual(normalize_base_symbol("ETH/USDT:USDT"), "ETH")
|
||||
self.assertEqual(normalize_base_symbol("ETHUSDT"), "ETH")
|
||||
|
||||
def test_resolve_usdt_perp_symbol(self):
|
||||
ex = FakeExchange(
|
||||
{
|
||||
"ETH/USDT:USDT": {
|
||||
"base": "ETH",
|
||||
"quote": "USDT",
|
||||
"swap": True,
|
||||
"active": True,
|
||||
"contractSize": 1.0,
|
||||
"limits": {"amount": {"min": 0.001}},
|
||||
"precision": {"price": 2, "amount": 3},
|
||||
}
|
||||
}
|
||||
)
|
||||
sym, err = resolve_usdt_perp_symbol(ex, "ETH")
|
||||
self.assertIsNone(err)
|
||||
self.assertEqual(sym, "ETH/USDT:USDT")
|
||||
|
||||
def test_amount_decimals_from_exchange(self):
|
||||
ex = FakeExchange({})
|
||||
self.assertEqual(amount_decimals_from_exchange(ex, "ETH/USDT:USDT"), 3)
|
||||
|
||||
def test_make_amount_precise_fn_from_market(self):
|
||||
fn = make_amount_precise_fn_from_market({"amount_decimals": 3, "min_amount": 0.001})
|
||||
self.assertEqual(fn(1.23456), 1.234)
|
||||
self.assertIsNone(fn(0.0001))
|
||||
|
||||
@patch.dict("os.environ", {"HUB_BRIDGE_TOKEN": "test-token"}, clear=False)
|
||||
def test_hub_headers_use_x_hub_token(self):
|
||||
from lib.hub.hub_calculator_market_lib import _hub_headers
|
||||
|
||||
self.assertEqual(_hub_headers(), {"X-Hub-Token": "test-token"})
|
||||
|
||||
@patch("lib.hub.hub_calculator_market_lib.fetch_instance_market_sync")
|
||||
def test_get_calculator_market_from_instance(self, fetch_mock):
|
||||
fetch_mock.return_value = {
|
||||
"ok": True,
|
||||
"base": "ETH",
|
||||
"exchange_symbol": "ETH/USDT:USDT",
|
||||
"display_symbol": "ETH/USDT",
|
||||
"contract_size": 0.01,
|
||||
"price_tick": 0.01,
|
||||
"price_decimals": 2,
|
||||
"amount_decimals": 2,
|
||||
"min_amount": 0.01,
|
||||
}
|
||||
ex = {
|
||||
"id": "0",
|
||||
"key": "binance",
|
||||
"name": "币安 · crypto_monitor_binance",
|
||||
"enabled": True,
|
||||
"flask_url": "http://127.0.0.1:5001",
|
||||
}
|
||||
data, err = get_calculator_market("0", "ETH", ex=ex)
|
||||
self.assertIsNone(err)
|
||||
self.assertIsNotNone(data)
|
||||
assert data is not None
|
||||
self.assertEqual(data["exchange_id"], "0")
|
||||
self.assertEqual(data["exchange_name"], "币安 · crypto_monitor_binance")
|
||||
self.assertEqual(data["contract_size"], 0.01)
|
||||
|
||||
@patch("lib.hub.hub_calculator_market_lib.enabled_exchanges")
|
||||
def test_list_calculator_exchanges(self, enabled_mock):
|
||||
enabled_mock.return_value = [
|
||||
{"id": "0", "key": "binance", "name": "币安", "enabled": True},
|
||||
]
|
||||
rows = list_calculator_exchanges()
|
||||
self.assertEqual(len(rows), 1)
|
||||
self.assertEqual(rows[0]["id"], "0")
|
||||
|
||||
def test_find_exchange_by_id(self):
|
||||
with patch(
|
||||
"lib.hub.hub_calculator_market_lib.load_settings",
|
||||
return_value={"exchanges": [{"id": "2", "key": "gate", "name": "Gate"}]},
|
||||
):
|
||||
self.assertEqual(find_exchange("2")["key"], "gate")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
+157
-157
@@ -1,157 +1,157 @@
|
||||
"""开仓计划库:CRUD 与胜率统计。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from hub_entry_plan_lib import (
|
||||
compute_entry_plan_stats,
|
||||
create_entry_plan,
|
||||
delete_entry_plan,
|
||||
init_db,
|
||||
list_entry_plans,
|
||||
normalize_plan_symbol,
|
||||
resolve_stats_date_bounds,
|
||||
update_entry_plan,
|
||||
)
|
||||
|
||||
|
||||
def _base_payload(**overrides):
|
||||
data = {
|
||||
"plan_date": "2026-06-14",
|
||||
"exchange_key": "binance",
|
||||
"symbol": "BTC",
|
||||
"plan_type": "trend",
|
||||
"trend_timeframe": "4h",
|
||||
"entry_timeframe": "15m",
|
||||
"direction": "long",
|
||||
"target_level": "70000",
|
||||
"current_range": "68000-69000",
|
||||
"entry_scheme": "breakout",
|
||||
"note": "test",
|
||||
}
|
||||
data.update(overrides)
|
||||
return data
|
||||
|
||||
|
||||
def test_normalize_plan_symbol():
|
||||
assert normalize_plan_symbol("btc") == "BTC/USDT"
|
||||
assert normalize_plan_symbol("ETH/USDT") == "ETH/USDT"
|
||||
|
||||
|
||||
def test_create_without_entry_scheme():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
payload = _base_payload()
|
||||
del payload["entry_scheme"]
|
||||
row = create_entry_plan(payload, db_path=db)
|
||||
assert row["entry_scheme"] == ""
|
||||
assert row["entry_scheme_label"] == "待填写"
|
||||
|
||||
|
||||
def test_archive_requires_entry_scheme():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
payload = _base_payload()
|
||||
del payload["entry_scheme"]
|
||||
row = create_entry_plan(payload, db_path=db)
|
||||
try:
|
||||
update_entry_plan(int(row["id"]), {"result": "win"}, db_path=db)
|
||||
assert False, "expected ValueError"
|
||||
except ValueError as e:
|
||||
assert "入场方案" in str(e)
|
||||
updated = update_entry_plan(
|
||||
int(row["id"]),
|
||||
{"entry_scheme": "breakout", "result": "win"},
|
||||
db_path=db,
|
||||
)
|
||||
assert updated["status"] == "archived"
|
||||
|
||||
|
||||
def test_create_list_delete_active_plan():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
row = create_entry_plan(_base_payload(), db_path=db)
|
||||
assert row["status"] == "active"
|
||||
assert row["symbol"] == "BTC/USDT"
|
||||
active = list_entry_plans(status="active", db_path=db)
|
||||
assert len(active) == 1
|
||||
assert delete_entry_plan(int(row["id"]), db_path=db) is True
|
||||
assert list_entry_plans(status="active", db_path=db) == []
|
||||
|
||||
|
||||
def test_archive_on_result():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
row = create_entry_plan(_base_payload(symbol="SOL"), db_path=db)
|
||||
updated = update_entry_plan(
|
||||
int(row["id"]),
|
||||
{"result": "win", "pnl_amount": 12.5},
|
||||
db_path=db,
|
||||
)
|
||||
assert updated["status"] == "archived"
|
||||
assert updated["result"] == "win"
|
||||
assert updated["pnl_amount"] == 12.5
|
||||
assert list_entry_plans(status="active", db_path=db) == []
|
||||
archived = list_entry_plans(status="archived", db_path=db)
|
||||
assert len(archived) == 1
|
||||
|
||||
|
||||
def test_archive_without_pnl_amount():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
row = create_entry_plan(_base_payload(symbol="DOGE"), db_path=db)
|
||||
updated = update_entry_plan(int(row["id"]), {"result": "loss"}, db_path=db)
|
||||
assert updated["status"] == "archived"
|
||||
assert updated["pnl_amount"] is None
|
||||
|
||||
|
||||
def test_cannot_delete_archived():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
row = create_entry_plan(_base_payload(), db_path=db)
|
||||
update_entry_plan(int(row["id"]), {"result": "win"}, db_path=db)
|
||||
try:
|
||||
delete_entry_plan(int(row["id"]), db_path=db)
|
||||
assert False, "expected ValueError"
|
||||
except ValueError as e:
|
||||
assert "仅进行中" in str(e)
|
||||
|
||||
|
||||
def test_compute_stats_by_symbol():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
for sym, res in (("BTC", "win"), ("BTC", "loss"), ("ETH", "win")):
|
||||
row = create_entry_plan(_base_payload(symbol=sym), db_path=db)
|
||||
update_entry_plan(int(row["id"]), {"result": res}, db_path=db)
|
||||
stats = compute_entry_plan_stats(dimension="symbol", period="all", db_path=db)
|
||||
by_sym = {it["key"]: it for it in stats["items"]}
|
||||
assert by_sym["BTC/USDT"]["win_count"] == 1
|
||||
assert by_sym["BTC/USDT"]["loss_count"] == 1
|
||||
assert by_sym["BTC/USDT"]["win_rate"] == 50.0
|
||||
assert by_sym["ETH/USDT"]["win_count"] == 1
|
||||
|
||||
|
||||
def test_stats_period_range_filter():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
row1 = create_entry_plan(_base_payload(plan_date="2026-06-01"), db_path=db)
|
||||
row2 = create_entry_plan(_base_payload(plan_date="2026-06-20", symbol="ETH"), db_path=db)
|
||||
update_entry_plan(int(row1["id"]), {"result": "win"}, db_path=db)
|
||||
update_entry_plan(int(row2["id"]), {"result": "loss"}, db_path=db)
|
||||
stats = compute_entry_plan_stats(
|
||||
dimension="symbol",
|
||||
period="range",
|
||||
date_from="2026-06-01",
|
||||
date_to="2026-06-10",
|
||||
db_path=db,
|
||||
)
|
||||
assert len(stats["items"]) == 1
|
||||
assert stats["items"][0]["key"] == "BTC/USDT"
|
||||
|
||||
|
||||
def test_resolve_stats_date_bounds():
|
||||
df, dt, label = resolve_stats_date_bounds(period="all")
|
||||
assert df is None and dt is None
|
||||
assert "全部" in label
|
||||
"""开仓计划库:CRUD 与胜率统计。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from lib.hub.hub_entry_plan_lib import (
|
||||
compute_entry_plan_stats,
|
||||
create_entry_plan,
|
||||
delete_entry_plan,
|
||||
init_db,
|
||||
list_entry_plans,
|
||||
normalize_plan_symbol,
|
||||
resolve_stats_date_bounds,
|
||||
update_entry_plan,
|
||||
)
|
||||
|
||||
|
||||
def _base_payload(**overrides):
|
||||
data = {
|
||||
"plan_date": "2026-06-14",
|
||||
"exchange_key": "binance",
|
||||
"symbol": "BTC",
|
||||
"plan_type": "trend",
|
||||
"trend_timeframe": "4h",
|
||||
"entry_timeframe": "15m",
|
||||
"direction": "long",
|
||||
"target_level": "70000",
|
||||
"current_range": "68000-69000",
|
||||
"entry_scheme": "breakout",
|
||||
"note": "test",
|
||||
}
|
||||
data.update(overrides)
|
||||
return data
|
||||
|
||||
|
||||
def test_normalize_plan_symbol():
|
||||
assert normalize_plan_symbol("btc") == "BTC/USDT"
|
||||
assert normalize_plan_symbol("ETH/USDT") == "ETH/USDT"
|
||||
|
||||
|
||||
def test_create_without_entry_scheme():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
payload = _base_payload()
|
||||
del payload["entry_scheme"]
|
||||
row = create_entry_plan(payload, db_path=db)
|
||||
assert row["entry_scheme"] == ""
|
||||
assert row["entry_scheme_label"] == "待填写"
|
||||
|
||||
|
||||
def test_archive_requires_entry_scheme():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
payload = _base_payload()
|
||||
del payload["entry_scheme"]
|
||||
row = create_entry_plan(payload, db_path=db)
|
||||
try:
|
||||
update_entry_plan(int(row["id"]), {"result": "win"}, db_path=db)
|
||||
assert False, "expected ValueError"
|
||||
except ValueError as e:
|
||||
assert "入场方案" in str(e)
|
||||
updated = update_entry_plan(
|
||||
int(row["id"]),
|
||||
{"entry_scheme": "breakout", "result": "win"},
|
||||
db_path=db,
|
||||
)
|
||||
assert updated["status"] == "archived"
|
||||
|
||||
|
||||
def test_create_list_delete_active_plan():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
row = create_entry_plan(_base_payload(), db_path=db)
|
||||
assert row["status"] == "active"
|
||||
assert row["symbol"] == "BTC/USDT"
|
||||
active = list_entry_plans(status="active", db_path=db)
|
||||
assert len(active) == 1
|
||||
assert delete_entry_plan(int(row["id"]), db_path=db) is True
|
||||
assert list_entry_plans(status="active", db_path=db) == []
|
||||
|
||||
|
||||
def test_archive_on_result():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
row = create_entry_plan(_base_payload(symbol="SOL"), db_path=db)
|
||||
updated = update_entry_plan(
|
||||
int(row["id"]),
|
||||
{"result": "win", "pnl_amount": 12.5},
|
||||
db_path=db,
|
||||
)
|
||||
assert updated["status"] == "archived"
|
||||
assert updated["result"] == "win"
|
||||
assert updated["pnl_amount"] == 12.5
|
||||
assert list_entry_plans(status="active", db_path=db) == []
|
||||
archived = list_entry_plans(status="archived", db_path=db)
|
||||
assert len(archived) == 1
|
||||
|
||||
|
||||
def test_archive_without_pnl_amount():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
row = create_entry_plan(_base_payload(symbol="DOGE"), db_path=db)
|
||||
updated = update_entry_plan(int(row["id"]), {"result": "loss"}, db_path=db)
|
||||
assert updated["status"] == "archived"
|
||||
assert updated["pnl_amount"] is None
|
||||
|
||||
|
||||
def test_cannot_delete_archived():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
row = create_entry_plan(_base_payload(), db_path=db)
|
||||
update_entry_plan(int(row["id"]), {"result": "win"}, db_path=db)
|
||||
try:
|
||||
delete_entry_plan(int(row["id"]), db_path=db)
|
||||
assert False, "expected ValueError"
|
||||
except ValueError as e:
|
||||
assert "仅进行中" in str(e)
|
||||
|
||||
|
||||
def test_compute_stats_by_symbol():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
for sym, res in (("BTC", "win"), ("BTC", "loss"), ("ETH", "win")):
|
||||
row = create_entry_plan(_base_payload(symbol=sym), db_path=db)
|
||||
update_entry_plan(int(row["id"]), {"result": res}, db_path=db)
|
||||
stats = compute_entry_plan_stats(dimension="symbol", period="all", db_path=db)
|
||||
by_sym = {it["key"]: it for it in stats["items"]}
|
||||
assert by_sym["BTC/USDT"]["win_count"] == 1
|
||||
assert by_sym["BTC/USDT"]["loss_count"] == 1
|
||||
assert by_sym["BTC/USDT"]["win_rate"] == 50.0
|
||||
assert by_sym["ETH/USDT"]["win_count"] == 1
|
||||
|
||||
|
||||
def test_stats_period_range_filter():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "plans.db"
|
||||
row1 = create_entry_plan(_base_payload(plan_date="2026-06-01"), db_path=db)
|
||||
row2 = create_entry_plan(_base_payload(plan_date="2026-06-20", symbol="ETH"), db_path=db)
|
||||
update_entry_plan(int(row1["id"]), {"result": "win"}, db_path=db)
|
||||
update_entry_plan(int(row2["id"]), {"result": "loss"}, db_path=db)
|
||||
stats = compute_entry_plan_stats(
|
||||
dimension="symbol",
|
||||
period="range",
|
||||
date_from="2026-06-01",
|
||||
date_to="2026-06-10",
|
||||
db_path=db,
|
||||
)
|
||||
assert len(stats["items"]) == 1
|
||||
assert stats["items"][0]["key"] == "BTC/USDT"
|
||||
|
||||
|
||||
def test_resolve_stats_date_bounds():
|
||||
df, dt, label = resolve_stats_date_bounds(period="all")
|
||||
assert df is None and dt is None
|
||||
assert "全部" in label
|
||||
|
||||
+113
-113
@@ -1,113 +1,113 @@
|
||||
"""hub_fund_history_lib:总资金、回撤与日快照。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from hub_fund_history_lib import (
|
||||
account_total_usdt,
|
||||
build_fund_overview,
|
||||
compute_drawdown,
|
||||
get_fund_history,
|
||||
record_fund_snapshot,
|
||||
)
|
||||
|
||||
|
||||
def test_account_total_requires_both_sides():
|
||||
assert account_total_usdt(10, 20) == 30.0
|
||||
assert account_total_usdt(10, None) is None
|
||||
assert account_total_usdt(None, 5) is None
|
||||
|
||||
|
||||
def test_compute_drawdown():
|
||||
dd = compute_drawdown([100, 120, 90, 110])
|
||||
assert dd["peak_usdt"] == 120.0
|
||||
assert dd["max_drawdown_u"] == 30.0
|
||||
assert dd["max_drawdown_pct"] == 25.0
|
||||
|
||||
|
||||
def test_build_fund_overview_skips_unmonitored(tmp_path, monkeypatch):
|
||||
hist_path = tmp_path / "hub_fund_history.json"
|
||||
monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_PATH", hist_path)
|
||||
record_fund_snapshot(
|
||||
"2026-06-01",
|
||||
[
|
||||
{
|
||||
"key": "binance",
|
||||
"name": "Binance",
|
||||
"funding_usdt": 10,
|
||||
"trading_usdt": 20,
|
||||
"monitored": True,
|
||||
}
|
||||
],
|
||||
keep_days=180,
|
||||
)
|
||||
record_fund_snapshot(
|
||||
"2026-06-02",
|
||||
[
|
||||
{
|
||||
"key": "binance",
|
||||
"name": "Binance",
|
||||
"funding_usdt": 12,
|
||||
"trading_usdt": 18,
|
||||
"monitored": True,
|
||||
}
|
||||
],
|
||||
keep_days=180,
|
||||
)
|
||||
exchanges = [
|
||||
{"id": "0", "key": "binance", "name": "Binance", "enabled": True},
|
||||
{"id": "3", "key": "gate_bot", "name": "Gate Bot", "enabled": False},
|
||||
]
|
||||
board_rows = [
|
||||
{
|
||||
"key": "binance",
|
||||
"name": "Binance",
|
||||
"account_ok": True,
|
||||
"funding_usdt": 15,
|
||||
"trading_usdt": 25,
|
||||
}
|
||||
]
|
||||
out = build_fund_overview(
|
||||
exchanges,
|
||||
board_rows=board_rows,
|
||||
trading_day="2026-06-02",
|
||||
keep_days=180,
|
||||
)
|
||||
assert out["totals"]["total_usdt"] == 40.0
|
||||
assert out["totals"]["monitored_count"] == 1
|
||||
assert len(out["accounts"]) == 1
|
||||
assert all(a["monitored"] for a in out["accounts"])
|
||||
assert out["totals"]["drawdown"]["max_drawdown_u"] == 0.0
|
||||
|
||||
|
||||
def test_history_start_day_filters_older(tmp_path, monkeypatch):
|
||||
hist_path = tmp_path / "hub_fund_history.json"
|
||||
monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_PATH", hist_path)
|
||||
monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_START_DAY", "2026-06-09")
|
||||
record_fund_snapshot(
|
||||
"2026-06-01",
|
||||
[
|
||||
{
|
||||
"key": "binance",
|
||||
"name": "Binance",
|
||||
"funding_usdt": 1,
|
||||
"trading_usdt": 1,
|
||||
"monitored": True,
|
||||
}
|
||||
],
|
||||
keep_days=180,
|
||||
)
|
||||
record_fund_snapshot(
|
||||
"2026-06-09",
|
||||
[
|
||||
{
|
||||
"key": "binance",
|
||||
"name": "Binance",
|
||||
"funding_usdt": 10,
|
||||
"trading_usdt": 20,
|
||||
"monitored": True,
|
||||
}
|
||||
],
|
||||
keep_days=180,
|
||||
)
|
||||
hist = get_fund_history(anchor_day="2026-06-10", keep_days=180)
|
||||
assert "2026-06-01" not in hist
|
||||
assert "2026-06-09" in hist
|
||||
"""hub_fund_history_lib:总资金、回撤与日快照。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from lib.hub.hub_fund_history_lib import (
|
||||
account_total_usdt,
|
||||
build_fund_overview,
|
||||
compute_drawdown,
|
||||
get_fund_history,
|
||||
record_fund_snapshot,
|
||||
)
|
||||
|
||||
|
||||
def test_account_total_requires_both_sides():
|
||||
assert account_total_usdt(10, 20) == 30.0
|
||||
assert account_total_usdt(10, None) is None
|
||||
assert account_total_usdt(None, 5) is None
|
||||
|
||||
|
||||
def test_compute_drawdown():
|
||||
dd = compute_drawdown([100, 120, 90, 110])
|
||||
assert dd["peak_usdt"] == 120.0
|
||||
assert dd["max_drawdown_u"] == 30.0
|
||||
assert dd["max_drawdown_pct"] == 25.0
|
||||
|
||||
|
||||
def test_build_fund_overview_skips_unmonitored(tmp_path, monkeypatch):
|
||||
hist_path = tmp_path / "hub_fund_history.json"
|
||||
monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_PATH", hist_path)
|
||||
record_fund_snapshot(
|
||||
"2026-06-01",
|
||||
[
|
||||
{
|
||||
"key": "binance",
|
||||
"name": "Binance",
|
||||
"funding_usdt": 10,
|
||||
"trading_usdt": 20,
|
||||
"monitored": True,
|
||||
}
|
||||
],
|
||||
keep_days=180,
|
||||
)
|
||||
record_fund_snapshot(
|
||||
"2026-06-02",
|
||||
[
|
||||
{
|
||||
"key": "binance",
|
||||
"name": "Binance",
|
||||
"funding_usdt": 12,
|
||||
"trading_usdt": 18,
|
||||
"monitored": True,
|
||||
}
|
||||
],
|
||||
keep_days=180,
|
||||
)
|
||||
exchanges = [
|
||||
{"id": "0", "key": "binance", "name": "Binance", "enabled": True},
|
||||
{"id": "3", "key": "gate_bot", "name": "Gate Bot", "enabled": False},
|
||||
]
|
||||
board_rows = [
|
||||
{
|
||||
"key": "binance",
|
||||
"name": "Binance",
|
||||
"account_ok": True,
|
||||
"funding_usdt": 15,
|
||||
"trading_usdt": 25,
|
||||
}
|
||||
]
|
||||
out = build_fund_overview(
|
||||
exchanges,
|
||||
board_rows=board_rows,
|
||||
trading_day="2026-06-02",
|
||||
keep_days=180,
|
||||
)
|
||||
assert out["totals"]["total_usdt"] == 40.0
|
||||
assert out["totals"]["monitored_count"] == 1
|
||||
assert len(out["accounts"]) == 1
|
||||
assert all(a["monitored"] for a in out["accounts"])
|
||||
assert out["totals"]["drawdown"]["max_drawdown_u"] == 0.0
|
||||
|
||||
|
||||
def test_history_start_day_filters_older(tmp_path, monkeypatch):
|
||||
hist_path = tmp_path / "hub_fund_history.json"
|
||||
monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_PATH", hist_path)
|
||||
monkeypatch.setattr("hub_fund_history_lib.FUND_HISTORY_START_DAY", "2026-06-09")
|
||||
record_fund_snapshot(
|
||||
"2026-06-01",
|
||||
[
|
||||
{
|
||||
"key": "binance",
|
||||
"name": "Binance",
|
||||
"funding_usdt": 1,
|
||||
"trading_usdt": 1,
|
||||
"monitored": True,
|
||||
}
|
||||
],
|
||||
keep_days=180,
|
||||
)
|
||||
record_fund_snapshot(
|
||||
"2026-06-09",
|
||||
[
|
||||
{
|
||||
"key": "binance",
|
||||
"name": "Binance",
|
||||
"funding_usdt": 10,
|
||||
"trading_usdt": 20,
|
||||
"monitored": True,
|
||||
}
|
||||
],
|
||||
keep_days=180,
|
||||
)
|
||||
hist = get_fund_history(anchor_day="2026-06-10", keep_days=180)
|
||||
assert "2026-06-01" not in hist
|
||||
assert "2026-06-09" in hist
|
||||
|
||||
@@ -1,58 +1,58 @@
|
||||
"""hub_host_status_lib 单元测试。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hub_host_status_lib import _disk_path, _state, get_host_status
|
||||
|
||||
|
||||
class HubHostStatusLibTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
_state["primed"] = False
|
||||
_state["net_ts"] = 0.0
|
||||
_state["net_sent"] = 0
|
||||
_state["net_recv"] = 0
|
||||
|
||||
def test_disk_path_env_override(self):
|
||||
with patch.dict("os.environ", {"HUB_HOST_DISK_PATH": "/data"}, clear=False):
|
||||
self.assertEqual(_disk_path(), "/data")
|
||||
|
||||
def test_get_host_status_without_psutil(self):
|
||||
import builtins
|
||||
|
||||
real_import = builtins.__import__
|
||||
|
||||
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
if name == "psutil":
|
||||
raise ImportError("no psutil")
|
||||
return real_import(name, globals, locals, fromlist, level)
|
||||
|
||||
with patch("builtins.__import__", side_effect=fake_import):
|
||||
out = get_host_status()
|
||||
self.assertFalse(out.get("ok"))
|
||||
self.assertIn("psutil", out.get("msg", ""))
|
||||
|
||||
def test_get_host_status_payload(self):
|
||||
fake_vm = MagicMock(total=8_000_000_000, used=3_200_000_000, percent=40.0)
|
||||
fake_du = MagicMock(total=100_000_000_000, used=50_000_000_000)
|
||||
fake_net = MagicMock(bytes_sent=1_000_000, bytes_recv=2_000_000)
|
||||
fake_psutil = MagicMock()
|
||||
fake_psutil.cpu_percent.return_value = 12.5
|
||||
fake_psutil.cpu_count.return_value = 4
|
||||
fake_psutil.virtual_memory.return_value = fake_vm
|
||||
fake_psutil.disk_usage.return_value = fake_du
|
||||
fake_psutil.net_io_counters.return_value = fake_net
|
||||
fake_psutil.boot_time.return_value = 1_700_000_000.0
|
||||
with patch.dict(sys.modules, {"psutil": fake_psutil}):
|
||||
out = get_host_status()
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertEqual(out["cpu"]["percent"], 12.5)
|
||||
self.assertEqual(out["memory"]["percent"], 40.0)
|
||||
self.assertEqual(out["disk"]["percent"], 50.0)
|
||||
self.assertIn("network", out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""hub_host_status_lib 单元测试。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from lib.hub.hub_host_status_lib import _disk_path, _state, get_host_status
|
||||
|
||||
|
||||
class HubHostStatusLibTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
_state["primed"] = False
|
||||
_state["net_ts"] = 0.0
|
||||
_state["net_sent"] = 0
|
||||
_state["net_recv"] = 0
|
||||
|
||||
def test_disk_path_env_override(self):
|
||||
with patch.dict("os.environ", {"HUB_HOST_DISK_PATH": "/data"}, clear=False):
|
||||
self.assertEqual(_disk_path(), "/data")
|
||||
|
||||
def test_get_host_status_without_psutil(self):
|
||||
import builtins
|
||||
|
||||
real_import = builtins.__import__
|
||||
|
||||
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
if name == "psutil":
|
||||
raise ImportError("no psutil")
|
||||
return real_import(name, globals, locals, fromlist, level)
|
||||
|
||||
with patch("builtins.__import__", side_effect=fake_import):
|
||||
out = get_host_status()
|
||||
self.assertFalse(out.get("ok"))
|
||||
self.assertIn("psutil", out.get("msg", ""))
|
||||
|
||||
def test_get_host_status_payload(self):
|
||||
fake_vm = MagicMock(total=8_000_000_000, used=3_200_000_000, percent=40.0)
|
||||
fake_du = MagicMock(total=100_000_000_000, used=50_000_000_000)
|
||||
fake_net = MagicMock(bytes_sent=1_000_000, bytes_recv=2_000_000)
|
||||
fake_psutil = MagicMock()
|
||||
fake_psutil.cpu_percent.return_value = 12.5
|
||||
fake_psutil.cpu_count.return_value = 4
|
||||
fake_psutil.virtual_memory.return_value = fake_vm
|
||||
fake_psutil.disk_usage.return_value = fake_du
|
||||
fake_psutil.net_io_counters.return_value = fake_net
|
||||
fake_psutil.boot_time.return_value = 1_700_000_000.0
|
||||
with patch.dict(sys.modules, {"psutil": fake_psutil}):
|
||||
out = get_host_status()
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertEqual(out["cpu"]["percent"], 12.5)
|
||||
self.assertEqual(out["memory"]["percent"], 40.0)
|
||||
self.assertEqual(out["disk"]["percent"], 50.0)
|
||||
self.assertIn("network", out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
+466
-466
@@ -1,466 +1,466 @@
|
||||
"""中控 K 线库:分周期保留、聚合与分页读取。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from hub_kline_store import (
|
||||
HUB_KLINE_REMOTE_FETCH_CAP,
|
||||
_since_ms_for_span,
|
||||
clear_series_bars,
|
||||
init_db,
|
||||
load_bars_before,
|
||||
load_bars_latest,
|
||||
purge_retention,
|
||||
purge_timeframe_by_days,
|
||||
resolve_chart_bars,
|
||||
retention_days,
|
||||
trim_contiguous_tail,
|
||||
upsert_bars,
|
||||
)
|
||||
from hub_ohlcv_lib import (
|
||||
TIMEFRAME_MS,
|
||||
bar_limit_for_timeframe,
|
||||
chart_fetch_start_ms,
|
||||
chart_initial_limit,
|
||||
last_closed_bar_open_ms,
|
||||
window_start_ms,
|
||||
)
|
||||
|
||||
|
||||
class TestHubKlineStore(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmp = tempfile.TemporaryDirectory()
|
||||
self.db = Path(self.tmp.name) / "test_hub_kline.db"
|
||||
|
||||
def tearDown(self):
|
||||
self.tmp.cleanup()
|
||||
|
||||
def test_bar_limits(self):
|
||||
self.assertEqual(bar_limit_for_timeframe("5m"), 5000)
|
||||
self.assertEqual(bar_limit_for_timeframe("1h"), 1000)
|
||||
self.assertEqual(bar_limit_for_timeframe("1d"), 1000)
|
||||
self.assertEqual(bar_limit_for_timeframe("1w"), 500)
|
||||
self.assertEqual(chart_initial_limit("5m"), 2000)
|
||||
self.assertEqual(chart_initial_limit("1h"), 1000)
|
||||
self.assertEqual(chart_initial_limit("1d"), 500)
|
||||
|
||||
def test_chart_fetch_window_exceeds_retention(self):
|
||||
now = int(time.time() * 1000)
|
||||
need = bar_limit_for_timeframe("1d")
|
||||
fetch_start = chart_fetch_start_ms("1d", need, now)
|
||||
db_start = window_start_ms("1d", need, retention_days(), now)
|
||||
self.assertLess(fetch_start, db_start)
|
||||
|
||||
def test_purge_retention_5m_one_year(self):
|
||||
init_db(self.db)
|
||||
old_ms = int(time.time() * 1000) - 400 * 86400000
|
||||
upsert_bars(
|
||||
"okx",
|
||||
"BTC/USDT",
|
||||
"5m",
|
||||
[
|
||||
{
|
||||
"open_time_ms": old_ms,
|
||||
"open": 1,
|
||||
"high": 2,
|
||||
"low": 0.5,
|
||||
"close": 1.5,
|
||||
"volume": 10,
|
||||
}
|
||||
],
|
||||
self.db,
|
||||
)
|
||||
n = purge_timeframe_by_days("5m", 365, self.db)
|
||||
self.assertGreaterEqual(n, 1)
|
||||
rows = load_bars_latest("okx", "BTC/USDT", "5m", 10, self.db)
|
||||
self.assertEqual(len(rows), 0)
|
||||
|
||||
def test_purge_retention_keeps_1d(self):
|
||||
init_db(self.db)
|
||||
old_ms = int(time.time() * 1000) - 400 * 86400000
|
||||
upsert_bars(
|
||||
"okx",
|
||||
"BTC/USDT",
|
||||
"1d",
|
||||
[
|
||||
{
|
||||
"open_time_ms": old_ms,
|
||||
"open": 1,
|
||||
"high": 2,
|
||||
"low": 0.5,
|
||||
"close": 1.5,
|
||||
"volume": 10,
|
||||
}
|
||||
],
|
||||
self.db,
|
||||
)
|
||||
purge_retention(self.db)
|
||||
rows = load_bars_latest("okx", "BTC/USDT", "1d", 10, self.db)
|
||||
self.assertEqual(len(rows), 1)
|
||||
|
||||
def test_resolve_uses_cache_without_remote(self):
|
||||
init_db(self.db)
|
||||
now = int(time.time() * 1000)
|
||||
tf = "5m"
|
||||
period = TIMEFRAME_MS[tf]
|
||||
last_closed = last_closed_bar_open_ms(tf, now)
|
||||
bars = []
|
||||
for i in range(400):
|
||||
oms = last_closed - (399 - i) * period
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": oms,
|
||||
"open": 100 + i,
|
||||
"high": 101 + i,
|
||||
"low": 99 + i,
|
||||
"close": 100.5 + i,
|
||||
"volume": 1000 + i,
|
||||
}
|
||||
)
|
||||
upsert_bars("okx", "ETH/USDT", tf, bars, self.db)
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
self.fail("不应请求交易所")
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"okx",
|
||||
"ETH/USDT",
|
||||
tf,
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
limit=300,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertEqual(len(out.get("candles") or []), 300)
|
||||
|
||||
def test_resolve_15m_reads_native_bars(self):
|
||||
init_db(self.db)
|
||||
now = int(time.time() * 1000)
|
||||
period = TIMEFRAME_MS["15m"]
|
||||
last_closed = last_closed_bar_open_ms("15m", now)
|
||||
bars = []
|
||||
for i in range(12):
|
||||
oms = last_closed - (11 - i) * period
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": oms,
|
||||
"open": 1.0 + i,
|
||||
"high": 2.0 + i,
|
||||
"low": 0.5 + i,
|
||||
"close": 1.5 + i,
|
||||
"volume": 10.0,
|
||||
}
|
||||
)
|
||||
upsert_bars("okx", "ETH/USDT", "15m", bars, self.db)
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
self.fail("不应请求交易所")
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"okx",
|
||||
"ETH/USDT",
|
||||
"15m",
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
limit=10,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertEqual(out.get("source"), "db")
|
||||
self.assertEqual(out.get("storage_timeframe"), "15m")
|
||||
self.assertGreaterEqual(len(out.get("candles") or []), 10)
|
||||
|
||||
def test_load_bars_before(self):
|
||||
init_db(self.db)
|
||||
period = TIMEFRAME_MS["1h"]
|
||||
base = 1_700_000_000_000
|
||||
bars = []
|
||||
for i in range(5):
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": base + i * period,
|
||||
"open": 1,
|
||||
"high": 2,
|
||||
"low": 0.5,
|
||||
"close": 1.5,
|
||||
"volume": 1,
|
||||
}
|
||||
)
|
||||
upsert_bars("okx", "BTC/USDT", "1h", bars, self.db)
|
||||
before = base + 3 * period
|
||||
got = load_bars_before("okx", "BTC/USDT", "1h", before, 2, self.db)
|
||||
self.assertEqual(len(got), 2)
|
||||
self.assertEqual(got[-1]["open_time_ms"], base + 2 * period)
|
||||
|
||||
def test_trim_contiguous_tail_drops_orphan_prefix(self):
|
||||
period = TIMEFRAME_MS["15m"]
|
||||
base_old = 1_700_000_000_000
|
||||
base_new = base_old + period * 500
|
||||
bars = []
|
||||
for i in range(3):
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": base_old + i * period,
|
||||
"open": 1,
|
||||
"high": 2,
|
||||
"low": 0.5,
|
||||
"close": 1.5,
|
||||
"volume": 1,
|
||||
}
|
||||
)
|
||||
for i in range(5):
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": base_new + i * period,
|
||||
"open": 2,
|
||||
"high": 3,
|
||||
"low": 1.5,
|
||||
"close": 2.5,
|
||||
"volume": 2,
|
||||
}
|
||||
)
|
||||
trimmed, split = trim_contiguous_tail(bars, period)
|
||||
self.assertEqual(split, 3)
|
||||
self.assertEqual(len(trimmed), 5)
|
||||
self.assertEqual(trimmed[0]["open_time_ms"], base_new)
|
||||
|
||||
def test_resolve_drops_discontinuous_orphans(self):
|
||||
init_db(self.db)
|
||||
period = TIMEFRAME_MS["15m"]
|
||||
now = int(time.time() * 1000)
|
||||
old_ms = now - period * 800
|
||||
upsert_bars(
|
||||
"okx",
|
||||
"ONDO/USDT",
|
||||
"15m",
|
||||
[
|
||||
{
|
||||
"open_time_ms": old_ms,
|
||||
"open": 0.33,
|
||||
"high": 0.34,
|
||||
"low": 0.32,
|
||||
"close": 0.335,
|
||||
"volume": 100,
|
||||
}
|
||||
],
|
||||
self.db,
|
||||
)
|
||||
recent = []
|
||||
start = now - period * 20
|
||||
for i in range(20):
|
||||
recent.append(
|
||||
{
|
||||
"open_time_ms": start + i * period,
|
||||
"open": 0.35,
|
||||
"high": 0.36,
|
||||
"low": 0.34,
|
||||
"close": 0.355,
|
||||
"volume": 50,
|
||||
}
|
||||
)
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
return {"ok": True, "bars": recent, "price_tick": 0.0001}
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"okx",
|
||||
"ONDO/USDT",
|
||||
"15m",
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
limit=50,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
candles = out.get("candles") or []
|
||||
self.assertGreaterEqual(len(candles), 19)
|
||||
if len(candles) >= 2:
|
||||
for i in range(1, len(candles)):
|
||||
gap = candles[i]["time"] - candles[i - 1]["time"]
|
||||
self.assertLessEqual(gap, int(period / 1000 * 3.0))
|
||||
|
||||
def test_resolve_refetches_when_db_has_discontinuous_full_count(self):
|
||||
init_db(self.db)
|
||||
period = TIMEFRAME_MS["15m"]
|
||||
now = int(time.time() * 1000)
|
||||
old_start = now - period * 3000
|
||||
recent_start = now - period * 25
|
||||
old_bars = [
|
||||
{
|
||||
"open_time_ms": old_start + i * period,
|
||||
"open": 62000,
|
||||
"high": 62100,
|
||||
"low": 61900,
|
||||
"close": 62050,
|
||||
"volume": 10,
|
||||
}
|
||||
for i in range(500)
|
||||
]
|
||||
recent = [
|
||||
{
|
||||
"open_time_ms": recent_start + i * period,
|
||||
"open": 104000,
|
||||
"high": 104100,
|
||||
"low": 103900,
|
||||
"close": 104050,
|
||||
"volume": 20,
|
||||
}
|
||||
for i in range(30)
|
||||
]
|
||||
upsert_bars("binance", "BTC/USDT", "15m", old_bars, self.db)
|
||||
upsert_bars("binance", "BTC/USDT", "15m", recent, self.db)
|
||||
fetch_calls = []
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
fetch_calls.append(dict(kwargs))
|
||||
full = []
|
||||
start = now - period * 120
|
||||
for i in range(120):
|
||||
full.append(
|
||||
{
|
||||
"open_time_ms": start + i * period,
|
||||
"open": 104000 + i,
|
||||
"high": 104100 + i,
|
||||
"low": 103900 + i,
|
||||
"close": 104050 + i,
|
||||
"volume": 30,
|
||||
}
|
||||
)
|
||||
return {"ok": True, "bars": full, "price_tick": 0.01}
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"binance",
|
||||
"BTC/USDT",
|
||||
"15m",
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
limit=2000,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertGreater(len(fetch_calls), 0)
|
||||
self.assertGreaterEqual(len(out.get("candles") or []), 100)
|
||||
self.assertGreater(int(out.get("fetched") or 0), 0)
|
||||
|
||||
def test_clear_series_and_force_refetch(self):
|
||||
init_db(self.db)
|
||||
period = TIMEFRAME_MS["5m"]
|
||||
now = int(time.time() * 1000)
|
||||
stale = [
|
||||
{
|
||||
"open_time_ms": now - period * (i + 100),
|
||||
"open": 1,
|
||||
"high": 2,
|
||||
"low": 0.5,
|
||||
"close": 1.5,
|
||||
"volume": 1,
|
||||
}
|
||||
for i in range(40)
|
||||
]
|
||||
upsert_bars("binance", "BTC/USDT", "5m", stale, self.db)
|
||||
self.assertEqual(len(load_bars_latest("binance", "BTC/USDT", "5m", 100, self.db)), 40)
|
||||
removed = clear_series_bars("binance", "BTC/USDT", "5m", self.db)
|
||||
self.assertEqual(removed, 40)
|
||||
self.assertEqual(len(load_bars_latest("binance", "BTC/USDT", "5m", 100, self.db)), 0)
|
||||
|
||||
fresh = [
|
||||
{
|
||||
"open_time_ms": now - period * (20 - i),
|
||||
"open": 10,
|
||||
"high": 11,
|
||||
"low": 9,
|
||||
"close": 10.5,
|
||||
"volume": 2,
|
||||
}
|
||||
for i in range(20)
|
||||
]
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
return {"ok": True, "bars": fresh, "price_tick": 0.01}
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"binance",
|
||||
"BTC/USDT",
|
||||
"5m",
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
force_refresh=True,
|
||||
clear_db=True,
|
||||
limit=50,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertGreaterEqual(int(out.get("cleared") or 0), 0)
|
||||
self.assertGreater(int(out.get("fetched") or 0), 0)
|
||||
self.assertGreaterEqual(len(out.get("candles") or []), 19)
|
||||
|
||||
def test_since_span_matches_fetch_limit_not_need(self):
|
||||
period = TIMEFRAME_MS["15m"]
|
||||
now_ms = 1_800_000_000_000
|
||||
fetch_limit = HUB_KLINE_REMOTE_FETCH_CAP
|
||||
since = _since_ms_for_span(
|
||||
now_ms=now_ms,
|
||||
period_ms=period,
|
||||
span_bars=fetch_limit,
|
||||
cutoff_ms=0,
|
||||
)
|
||||
self.assertEqual(since, now_ms - period * fetch_limit)
|
||||
wrong_since = now_ms - period * chart_initial_limit("15m")
|
||||
self.assertGreater(since, wrong_since)
|
||||
|
||||
def test_thin_series_tail_refresh_fetches_full_window(self):
|
||||
init_db(self.db)
|
||||
period = TIMEFRAME_MS["15m"]
|
||||
now = int(time.time() * 1000)
|
||||
last_closed = last_closed_bar_open_ms("15m", now)
|
||||
bars = [
|
||||
{
|
||||
"open_time_ms": last_closed - period * (150 - i),
|
||||
"open": 100000,
|
||||
"high": 100100,
|
||||
"low": 99900,
|
||||
"close": 100050,
|
||||
"volume": 1,
|
||||
}
|
||||
for i in range(150)
|
||||
]
|
||||
fetch_calls: list[dict] = []
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
fetch_calls.append(dict(kwargs))
|
||||
return {"ok": True, "bars": bars, "price_tick": 0.01}
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"binance",
|
||||
"BTC/USDT",
|
||||
"15m",
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
tail_refresh=True,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertGreaterEqual(len(out.get("candles") or []), 100)
|
||||
self.assertGreater(int(out.get("fetched") or 0), 0)
|
||||
self.assertTrue(any(int(c.get("limit") or 0) > 30 for c in fetch_calls))
|
||||
|
||||
def test_resolve_before_ms_exhausted(self):
|
||||
init_db(self.db)
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
return {"ok": False, "msg": "no remote"}
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"okx",
|
||||
"BTC/USDT",
|
||||
"5m",
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
limit=100,
|
||||
before_ms=int(time.time() * 1000),
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertEqual(out.get("candles"), [])
|
||||
self.assertTrue(out.get("exhausted"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""中控 K 线库:分周期保留、聚合与分页读取。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from lib.hub.hub_kline_store import (
|
||||
HUB_KLINE_REMOTE_FETCH_CAP,
|
||||
_since_ms_for_span,
|
||||
clear_series_bars,
|
||||
init_db,
|
||||
load_bars_before,
|
||||
load_bars_latest,
|
||||
purge_retention,
|
||||
purge_timeframe_by_days,
|
||||
resolve_chart_bars,
|
||||
retention_days,
|
||||
trim_contiguous_tail,
|
||||
upsert_bars,
|
||||
)
|
||||
from lib.hub.hub_ohlcv_lib import (
|
||||
TIMEFRAME_MS,
|
||||
bar_limit_for_timeframe,
|
||||
chart_fetch_start_ms,
|
||||
chart_initial_limit,
|
||||
last_closed_bar_open_ms,
|
||||
window_start_ms,
|
||||
)
|
||||
|
||||
|
||||
class TestHubKlineStore(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmp = tempfile.TemporaryDirectory()
|
||||
self.db = Path(self.tmp.name) / "test_hub_kline.db"
|
||||
|
||||
def tearDown(self):
|
||||
self.tmp.cleanup()
|
||||
|
||||
def test_bar_limits(self):
|
||||
self.assertEqual(bar_limit_for_timeframe("5m"), 5000)
|
||||
self.assertEqual(bar_limit_for_timeframe("1h"), 1000)
|
||||
self.assertEqual(bar_limit_for_timeframe("1d"), 1000)
|
||||
self.assertEqual(bar_limit_for_timeframe("1w"), 500)
|
||||
self.assertEqual(chart_initial_limit("5m"), 2000)
|
||||
self.assertEqual(chart_initial_limit("1h"), 1000)
|
||||
self.assertEqual(chart_initial_limit("1d"), 500)
|
||||
|
||||
def test_chart_fetch_window_exceeds_retention(self):
|
||||
now = int(time.time() * 1000)
|
||||
need = bar_limit_for_timeframe("1d")
|
||||
fetch_start = chart_fetch_start_ms("1d", need, now)
|
||||
db_start = window_start_ms("1d", need, retention_days(), now)
|
||||
self.assertLess(fetch_start, db_start)
|
||||
|
||||
def test_purge_retention_5m_one_year(self):
|
||||
init_db(self.db)
|
||||
old_ms = int(time.time() * 1000) - 400 * 86400000
|
||||
upsert_bars(
|
||||
"okx",
|
||||
"BTC/USDT",
|
||||
"5m",
|
||||
[
|
||||
{
|
||||
"open_time_ms": old_ms,
|
||||
"open": 1,
|
||||
"high": 2,
|
||||
"low": 0.5,
|
||||
"close": 1.5,
|
||||
"volume": 10,
|
||||
}
|
||||
],
|
||||
self.db,
|
||||
)
|
||||
n = purge_timeframe_by_days("5m", 365, self.db)
|
||||
self.assertGreaterEqual(n, 1)
|
||||
rows = load_bars_latest("okx", "BTC/USDT", "5m", 10, self.db)
|
||||
self.assertEqual(len(rows), 0)
|
||||
|
||||
def test_purge_retention_keeps_1d(self):
|
||||
init_db(self.db)
|
||||
old_ms = int(time.time() * 1000) - 400 * 86400000
|
||||
upsert_bars(
|
||||
"okx",
|
||||
"BTC/USDT",
|
||||
"1d",
|
||||
[
|
||||
{
|
||||
"open_time_ms": old_ms,
|
||||
"open": 1,
|
||||
"high": 2,
|
||||
"low": 0.5,
|
||||
"close": 1.5,
|
||||
"volume": 10,
|
||||
}
|
||||
],
|
||||
self.db,
|
||||
)
|
||||
purge_retention(self.db)
|
||||
rows = load_bars_latest("okx", "BTC/USDT", "1d", 10, self.db)
|
||||
self.assertEqual(len(rows), 1)
|
||||
|
||||
def test_resolve_uses_cache_without_remote(self):
|
||||
init_db(self.db)
|
||||
now = int(time.time() * 1000)
|
||||
tf = "5m"
|
||||
period = TIMEFRAME_MS[tf]
|
||||
last_closed = last_closed_bar_open_ms(tf, now)
|
||||
bars = []
|
||||
for i in range(400):
|
||||
oms = last_closed - (399 - i) * period
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": oms,
|
||||
"open": 100 + i,
|
||||
"high": 101 + i,
|
||||
"low": 99 + i,
|
||||
"close": 100.5 + i,
|
||||
"volume": 1000 + i,
|
||||
}
|
||||
)
|
||||
upsert_bars("okx", "ETH/USDT", tf, bars, self.db)
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
self.fail("不应请求交易所")
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"okx",
|
||||
"ETH/USDT",
|
||||
tf,
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
limit=300,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertEqual(len(out.get("candles") or []), 300)
|
||||
|
||||
def test_resolve_15m_reads_native_bars(self):
|
||||
init_db(self.db)
|
||||
now = int(time.time() * 1000)
|
||||
period = TIMEFRAME_MS["15m"]
|
||||
last_closed = last_closed_bar_open_ms("15m", now)
|
||||
bars = []
|
||||
for i in range(12):
|
||||
oms = last_closed - (11 - i) * period
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": oms,
|
||||
"open": 1.0 + i,
|
||||
"high": 2.0 + i,
|
||||
"low": 0.5 + i,
|
||||
"close": 1.5 + i,
|
||||
"volume": 10.0,
|
||||
}
|
||||
)
|
||||
upsert_bars("okx", "ETH/USDT", "15m", bars, self.db)
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
self.fail("不应请求交易所")
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"okx",
|
||||
"ETH/USDT",
|
||||
"15m",
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
limit=10,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertEqual(out.get("source"), "db")
|
||||
self.assertEqual(out.get("storage_timeframe"), "15m")
|
||||
self.assertGreaterEqual(len(out.get("candles") or []), 10)
|
||||
|
||||
def test_load_bars_before(self):
|
||||
init_db(self.db)
|
||||
period = TIMEFRAME_MS["1h"]
|
||||
base = 1_700_000_000_000
|
||||
bars = []
|
||||
for i in range(5):
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": base + i * period,
|
||||
"open": 1,
|
||||
"high": 2,
|
||||
"low": 0.5,
|
||||
"close": 1.5,
|
||||
"volume": 1,
|
||||
}
|
||||
)
|
||||
upsert_bars("okx", "BTC/USDT", "1h", bars, self.db)
|
||||
before = base + 3 * period
|
||||
got = load_bars_before("okx", "BTC/USDT", "1h", before, 2, self.db)
|
||||
self.assertEqual(len(got), 2)
|
||||
self.assertEqual(got[-1]["open_time_ms"], base + 2 * period)
|
||||
|
||||
def test_trim_contiguous_tail_drops_orphan_prefix(self):
|
||||
period = TIMEFRAME_MS["15m"]
|
||||
base_old = 1_700_000_000_000
|
||||
base_new = base_old + period * 500
|
||||
bars = []
|
||||
for i in range(3):
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": base_old + i * period,
|
||||
"open": 1,
|
||||
"high": 2,
|
||||
"low": 0.5,
|
||||
"close": 1.5,
|
||||
"volume": 1,
|
||||
}
|
||||
)
|
||||
for i in range(5):
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": base_new + i * period,
|
||||
"open": 2,
|
||||
"high": 3,
|
||||
"low": 1.5,
|
||||
"close": 2.5,
|
||||
"volume": 2,
|
||||
}
|
||||
)
|
||||
trimmed, split = trim_contiguous_tail(bars, period)
|
||||
self.assertEqual(split, 3)
|
||||
self.assertEqual(len(trimmed), 5)
|
||||
self.assertEqual(trimmed[0]["open_time_ms"], base_new)
|
||||
|
||||
def test_resolve_drops_discontinuous_orphans(self):
|
||||
init_db(self.db)
|
||||
period = TIMEFRAME_MS["15m"]
|
||||
now = int(time.time() * 1000)
|
||||
old_ms = now - period * 800
|
||||
upsert_bars(
|
||||
"okx",
|
||||
"ONDO/USDT",
|
||||
"15m",
|
||||
[
|
||||
{
|
||||
"open_time_ms": old_ms,
|
||||
"open": 0.33,
|
||||
"high": 0.34,
|
||||
"low": 0.32,
|
||||
"close": 0.335,
|
||||
"volume": 100,
|
||||
}
|
||||
],
|
||||
self.db,
|
||||
)
|
||||
recent = []
|
||||
start = now - period * 20
|
||||
for i in range(20):
|
||||
recent.append(
|
||||
{
|
||||
"open_time_ms": start + i * period,
|
||||
"open": 0.35,
|
||||
"high": 0.36,
|
||||
"low": 0.34,
|
||||
"close": 0.355,
|
||||
"volume": 50,
|
||||
}
|
||||
)
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
return {"ok": True, "bars": recent, "price_tick": 0.0001}
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"okx",
|
||||
"ONDO/USDT",
|
||||
"15m",
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
limit=50,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
candles = out.get("candles") or []
|
||||
self.assertGreaterEqual(len(candles), 19)
|
||||
if len(candles) >= 2:
|
||||
for i in range(1, len(candles)):
|
||||
gap = candles[i]["time"] - candles[i - 1]["time"]
|
||||
self.assertLessEqual(gap, int(period / 1000 * 3.0))
|
||||
|
||||
def test_resolve_refetches_when_db_has_discontinuous_full_count(self):
|
||||
init_db(self.db)
|
||||
period = TIMEFRAME_MS["15m"]
|
||||
now = int(time.time() * 1000)
|
||||
old_start = now - period * 3000
|
||||
recent_start = now - period * 25
|
||||
old_bars = [
|
||||
{
|
||||
"open_time_ms": old_start + i * period,
|
||||
"open": 62000,
|
||||
"high": 62100,
|
||||
"low": 61900,
|
||||
"close": 62050,
|
||||
"volume": 10,
|
||||
}
|
||||
for i in range(500)
|
||||
]
|
||||
recent = [
|
||||
{
|
||||
"open_time_ms": recent_start + i * period,
|
||||
"open": 104000,
|
||||
"high": 104100,
|
||||
"low": 103900,
|
||||
"close": 104050,
|
||||
"volume": 20,
|
||||
}
|
||||
for i in range(30)
|
||||
]
|
||||
upsert_bars("binance", "BTC/USDT", "15m", old_bars, self.db)
|
||||
upsert_bars("binance", "BTC/USDT", "15m", recent, self.db)
|
||||
fetch_calls = []
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
fetch_calls.append(dict(kwargs))
|
||||
full = []
|
||||
start = now - period * 120
|
||||
for i in range(120):
|
||||
full.append(
|
||||
{
|
||||
"open_time_ms": start + i * period,
|
||||
"open": 104000 + i,
|
||||
"high": 104100 + i,
|
||||
"low": 103900 + i,
|
||||
"close": 104050 + i,
|
||||
"volume": 30,
|
||||
}
|
||||
)
|
||||
return {"ok": True, "bars": full, "price_tick": 0.01}
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"binance",
|
||||
"BTC/USDT",
|
||||
"15m",
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
limit=2000,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertGreater(len(fetch_calls), 0)
|
||||
self.assertGreaterEqual(len(out.get("candles") or []), 100)
|
||||
self.assertGreater(int(out.get("fetched") or 0), 0)
|
||||
|
||||
def test_clear_series_and_force_refetch(self):
|
||||
init_db(self.db)
|
||||
period = TIMEFRAME_MS["5m"]
|
||||
now = int(time.time() * 1000)
|
||||
stale = [
|
||||
{
|
||||
"open_time_ms": now - period * (i + 100),
|
||||
"open": 1,
|
||||
"high": 2,
|
||||
"low": 0.5,
|
||||
"close": 1.5,
|
||||
"volume": 1,
|
||||
}
|
||||
for i in range(40)
|
||||
]
|
||||
upsert_bars("binance", "BTC/USDT", "5m", stale, self.db)
|
||||
self.assertEqual(len(load_bars_latest("binance", "BTC/USDT", "5m", 100, self.db)), 40)
|
||||
removed = clear_series_bars("binance", "BTC/USDT", "5m", self.db)
|
||||
self.assertEqual(removed, 40)
|
||||
self.assertEqual(len(load_bars_latest("binance", "BTC/USDT", "5m", 100, self.db)), 0)
|
||||
|
||||
fresh = [
|
||||
{
|
||||
"open_time_ms": now - period * (20 - i),
|
||||
"open": 10,
|
||||
"high": 11,
|
||||
"low": 9,
|
||||
"close": 10.5,
|
||||
"volume": 2,
|
||||
}
|
||||
for i in range(20)
|
||||
]
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
return {"ok": True, "bars": fresh, "price_tick": 0.01}
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"binance",
|
||||
"BTC/USDT",
|
||||
"5m",
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
force_refresh=True,
|
||||
clear_db=True,
|
||||
limit=50,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertGreaterEqual(int(out.get("cleared") or 0), 0)
|
||||
self.assertGreater(int(out.get("fetched") or 0), 0)
|
||||
self.assertGreaterEqual(len(out.get("candles") or []), 19)
|
||||
|
||||
def test_since_span_matches_fetch_limit_not_need(self):
|
||||
period = TIMEFRAME_MS["15m"]
|
||||
now_ms = 1_800_000_000_000
|
||||
fetch_limit = HUB_KLINE_REMOTE_FETCH_CAP
|
||||
since = _since_ms_for_span(
|
||||
now_ms=now_ms,
|
||||
period_ms=period,
|
||||
span_bars=fetch_limit,
|
||||
cutoff_ms=0,
|
||||
)
|
||||
self.assertEqual(since, now_ms - period * fetch_limit)
|
||||
wrong_since = now_ms - period * chart_initial_limit("15m")
|
||||
self.assertGreater(since, wrong_since)
|
||||
|
||||
def test_thin_series_tail_refresh_fetches_full_window(self):
|
||||
init_db(self.db)
|
||||
period = TIMEFRAME_MS["15m"]
|
||||
now = int(time.time() * 1000)
|
||||
last_closed = last_closed_bar_open_ms("15m", now)
|
||||
bars = [
|
||||
{
|
||||
"open_time_ms": last_closed - period * (150 - i),
|
||||
"open": 100000,
|
||||
"high": 100100,
|
||||
"low": 99900,
|
||||
"close": 100050,
|
||||
"volume": 1,
|
||||
}
|
||||
for i in range(150)
|
||||
]
|
||||
fetch_calls: list[dict] = []
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
fetch_calls.append(dict(kwargs))
|
||||
return {"ok": True, "bars": bars, "price_tick": 0.01}
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"binance",
|
||||
"BTC/USDT",
|
||||
"15m",
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
tail_refresh=True,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertGreaterEqual(len(out.get("candles") or []), 100)
|
||||
self.assertGreater(int(out.get("fetched") or 0), 0)
|
||||
self.assertTrue(any(int(c.get("limit") or 0) > 30 for c in fetch_calls))
|
||||
|
||||
def test_resolve_before_ms_exhausted(self):
|
||||
init_db(self.db)
|
||||
|
||||
def remote_fetch(**kwargs):
|
||||
return {"ok": False, "msg": "no remote"}
|
||||
|
||||
out = resolve_chart_bars(
|
||||
"okx",
|
||||
"BTC/USDT",
|
||||
"5m",
|
||||
remote_fetch,
|
||||
db_path=self.db,
|
||||
limit=100,
|
||||
before_ms=int(time.time() * 1000),
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertEqual(out.get("candles"), [])
|
||||
self.assertTrue(out.get("exhausted"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,73 +1,73 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
from hub_macro_calendar_lib import (
|
||||
build_banner_message,
|
||||
create_event,
|
||||
delete_event,
|
||||
enrich_alert,
|
||||
init_db,
|
||||
list_active_alerts,
|
||||
list_events,
|
||||
update_event,
|
||||
)
|
||||
|
||||
|
||||
class HubMacroCalendarLibTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmp = tempfile.TemporaryDirectory()
|
||||
self.db_path = Path(self.tmp.name) / "macro.db"
|
||||
init_db(self.db_path)
|
||||
|
||||
def tearDown(self):
|
||||
self.tmp.cleanup()
|
||||
|
||||
def test_create_and_list(self):
|
||||
row = create_event("cpi", "2026-06-18 20:30", note="核心CPI", db_path=self.db_path)
|
||||
self.assertEqual(row["event_type"], "cpi")
|
||||
self.assertEqual(row["event_at"], "2026-06-18 20:30")
|
||||
rows = list_events(now_ms=row["event_at_ms"] - 86400000, db_path=self.db_path)
|
||||
self.assertEqual(len(rows), 1)
|
||||
|
||||
def test_duplicate_rejected(self):
|
||||
create_event("fomc", "2026-07-01 02:00", db_path=self.db_path)
|
||||
with self.assertRaises(ValueError):
|
||||
create_event("fomc", "2026-07-01 02:00", db_path=self.db_path)
|
||||
|
||||
def test_active_window_and_messages(self):
|
||||
row = create_event("employment", "2026-06-18 20:30", db_path=self.db_path)
|
||||
t0 = int(row["event_at_ms"])
|
||||
inside = enrich_alert(row, now_ms=t0 - 30 * 60 * 1000)
|
||||
self.assertIsNotNone(inside)
|
||||
self.assertEqual(inside["phase"], "imminent")
|
||||
outside = enrich_alert(row, now_ms=t0 - 2 * 3600 * 1000)
|
||||
self.assertIsNone(outside)
|
||||
alerts = list_active_alerts(now_ms=t0 + 15 * 60 * 1000, db_path=self.db_path)
|
||||
self.assertEqual(len(alerts), 1)
|
||||
msg_pos = build_banner_message(alerts[0], has_positions=True)
|
||||
msg_flat = build_banner_message(alerts[0], has_positions=False)
|
||||
self.assertIn("注意仓位风险", msg_pos)
|
||||
self.assertIn("建议等待", msg_flat)
|
||||
|
||||
def test_update_and_delete(self):
|
||||
row = create_event("cpi", "2026-06-18 20:30", db_path=self.db_path)
|
||||
updated = update_event(
|
||||
row["id"],
|
||||
event_at="2026-06-18 21:00",
|
||||
note="修正时间",
|
||||
db_path=self.db_path,
|
||||
)
|
||||
self.assertEqual(updated["event_at"], "2026-06-18 21:00")
|
||||
self.assertTrue(delete_event(row["id"], db_path=self.db_path))
|
||||
self.assertEqual(len(list_events(now_ms=updated["event_at_ms"], db_path=self.db_path)), 0)
|
||||
|
||||
def test_invalid_type(self):
|
||||
with self.assertRaises(ValueError):
|
||||
create_event("nfp", "2026-06-18 20:30", db_path=self.db_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
from lib.hub.hub_macro_calendar_lib import (
|
||||
build_banner_message,
|
||||
create_event,
|
||||
delete_event,
|
||||
enrich_alert,
|
||||
init_db,
|
||||
list_active_alerts,
|
||||
list_events,
|
||||
update_event,
|
||||
)
|
||||
|
||||
|
||||
class HubMacroCalendarLibTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmp = tempfile.TemporaryDirectory()
|
||||
self.db_path = Path(self.tmp.name) / "macro.db"
|
||||
init_db(self.db_path)
|
||||
|
||||
def tearDown(self):
|
||||
self.tmp.cleanup()
|
||||
|
||||
def test_create_and_list(self):
|
||||
row = create_event("cpi", "2026-06-18 20:30", note="核心CPI", db_path=self.db_path)
|
||||
self.assertEqual(row["event_type"], "cpi")
|
||||
self.assertEqual(row["event_at"], "2026-06-18 20:30")
|
||||
rows = list_events(now_ms=row["event_at_ms"] - 86400000, db_path=self.db_path)
|
||||
self.assertEqual(len(rows), 1)
|
||||
|
||||
def test_duplicate_rejected(self):
|
||||
create_event("fomc", "2026-07-01 02:00", db_path=self.db_path)
|
||||
with self.assertRaises(ValueError):
|
||||
create_event("fomc", "2026-07-01 02:00", db_path=self.db_path)
|
||||
|
||||
def test_active_window_and_messages(self):
|
||||
row = create_event("employment", "2026-06-18 20:30", db_path=self.db_path)
|
||||
t0 = int(row["event_at_ms"])
|
||||
inside = enrich_alert(row, now_ms=t0 - 30 * 60 * 1000)
|
||||
self.assertIsNotNone(inside)
|
||||
self.assertEqual(inside["phase"], "imminent")
|
||||
outside = enrich_alert(row, now_ms=t0 - 2 * 3600 * 1000)
|
||||
self.assertIsNone(outside)
|
||||
alerts = list_active_alerts(now_ms=t0 + 15 * 60 * 1000, db_path=self.db_path)
|
||||
self.assertEqual(len(alerts), 1)
|
||||
msg_pos = build_banner_message(alerts[0], has_positions=True)
|
||||
msg_flat = build_banner_message(alerts[0], has_positions=False)
|
||||
self.assertIn("注意仓位风险", msg_pos)
|
||||
self.assertIn("建议等待", msg_flat)
|
||||
|
||||
def test_update_and_delete(self):
|
||||
row = create_event("cpi", "2026-06-18 20:30", db_path=self.db_path)
|
||||
updated = update_event(
|
||||
row["id"],
|
||||
event_at="2026-06-18 21:00",
|
||||
note="修正时间",
|
||||
db_path=self.db_path,
|
||||
)
|
||||
self.assertEqual(updated["event_at"], "2026-06-18 21:00")
|
||||
self.assertTrue(delete_event(row["id"], db_path=self.db_path))
|
||||
self.assertEqual(len(list_events(now_ms=updated["event_at_ms"], db_path=self.db_path)), 0)
|
||||
|
||||
def test_invalid_type(self):
|
||||
with self.assertRaises(ValueError):
|
||||
create_event("nfp", "2026-06-18 20:30", db_path=self.db_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
"""hub /api/hub/monitor:enrich 局部返回时须保留 keys。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from hub_bridge import build_hub_monitor_payload # noqa: E402
|
||||
|
||||
|
||||
class TestHubMonitorPayload(unittest.TestCase):
|
||||
def test_partial_enrich_keeps_keys(self):
|
||||
keys = [{"id": 7, "symbol": "BTC/USDT"}]
|
||||
orders = [{"id": 1}]
|
||||
trends = [{"id": 9, "symbol": "ETH/USDT"}]
|
||||
rolls = []
|
||||
|
||||
def enrich_only_trends(**_kw):
|
||||
return {"trends": [{"id": 9, "add_count": 2}]}
|
||||
|
||||
out = build_hub_monitor_payload(
|
||||
keys=keys,
|
||||
orders=orders,
|
||||
trends=trends,
|
||||
rolls=rolls,
|
||||
enrich=enrich_only_trends,
|
||||
)
|
||||
self.assertTrue(out["ok"])
|
||||
self.assertEqual(out["keys"], keys)
|
||||
self.assertEqual(out["orders"], orders)
|
||||
self.assertEqual(out["rolls"], rolls)
|
||||
self.assertEqual(out["trends"][0]["add_count"], 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""hub /api/hub/monitor:enrich 局部返回时须保留 keys。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from lib.hub.hub_bridge import build_hub_monitor_payload # noqa: E402
|
||||
|
||||
|
||||
class TestHubMonitorPayload(unittest.TestCase):
|
||||
def test_partial_enrich_keeps_keys(self):
|
||||
keys = [{"id": 7, "symbol": "BTC/USDT"}]
|
||||
orders = [{"id": 1}]
|
||||
trends = [{"id": 9, "symbol": "ETH/USDT"}]
|
||||
rolls = []
|
||||
|
||||
def enrich_only_trends(**_kw):
|
||||
return {"trends": [{"id": 9, "add_count": 2}]}
|
||||
|
||||
out = build_hub_monitor_payload(
|
||||
keys=keys,
|
||||
orders=orders,
|
||||
trends=trends,
|
||||
rolls=rolls,
|
||||
enrich=enrich_only_trends,
|
||||
)
|
||||
self.assertTrue(out["ok"])
|
||||
self.assertEqual(out["keys"], keys)
|
||||
self.assertEqual(out["orders"], orders)
|
||||
self.assertEqual(out["rolls"], rolls)
|
||||
self.assertEqual(out["trends"][0]["add_count"], 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
+222
-223
@@ -1,223 +1,222 @@
|
||||
"""hub_ohlcv_lib:分页拉取(Gate 等单次不足 chunk 时仍继续)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from hub_ohlcv_lib import (
|
||||
aggregate_ohlcv_bars,
|
||||
bars_spacing_matches_timeframe,
|
||||
fetch_ohlcv_for_hub,
|
||||
normalize_price_tick,
|
||||
)
|
||||
|
||||
|
||||
class _FakeExchange:
|
||||
def __init__(self, pages, *, timeframes=None):
|
||||
self.pages = list(pages)
|
||||
self.calls = []
|
||||
self.markets = {}
|
||||
self.timeframes = timeframes if timeframes is not None else {}
|
||||
|
||||
def fetch_ohlcv(self, symbol, timeframe=None, since=None, limit=None):
|
||||
self.calls.append(
|
||||
{"symbol": symbol, "since": since, "limit": limit, "timeframe": timeframe}
|
||||
)
|
||||
if not self.pages:
|
||||
return []
|
||||
page = self.pages.pop(0)
|
||||
if since is None:
|
||||
return page
|
||||
return [b for b in page if b[0] >= since]
|
||||
|
||||
|
||||
class TestHubOhlcvLib(unittest.TestCase):
|
||||
def test_normalize_price_tick_snaps_powers_of_ten(self):
|
||||
self.assertAlmostEqual(normalize_price_tick(0.00001), 0.00001)
|
||||
self.assertAlmostEqual(normalize_price_tick(0.001), 0.001)
|
||||
self.assertIsNone(normalize_price_tick(0))
|
||||
|
||||
def test_price_tick_from_decimal_precision(self):
|
||||
class _Ex:
|
||||
markets = {"BTC/USDT:USDT": {"precision": {"price": 2}, "info": {}, "limits": {}}}
|
||||
|
||||
def load_markets(self):
|
||||
return self.markets
|
||||
|
||||
def market(self, sym):
|
||||
return self.markets[sym]
|
||||
|
||||
def price_to_precision(self, sym, price):
|
||||
return "12345.67"
|
||||
|
||||
tick = __import__("hub_ohlcv_lib", fromlist=["price_tick_from_market"]).price_tick_from_market(
|
||||
_Ex(), "BTC/USDT:USDT"
|
||||
)
|
||||
self.assertAlmostEqual(tick, 0.01)
|
||||
|
||||
def test_price_tick_from_binance_price_filter(self):
|
||||
class _Ex:
|
||||
markets = {
|
||||
"BTC/USDT:USDT": {
|
||||
"precision": {"price": 2},
|
||||
"info": {
|
||||
"filters": [
|
||||
{"filterType": "PRICE_FILTER", "tickSize": "0.10"},
|
||||
{"filterType": "LOT_SIZE", "stepSize": "0.001"},
|
||||
]
|
||||
},
|
||||
"limits": {},
|
||||
}
|
||||
}
|
||||
|
||||
def load_markets(self):
|
||||
return self.markets
|
||||
|
||||
def market(self, sym):
|
||||
return self.markets[sym]
|
||||
|
||||
def price_to_precision(self, sym, price):
|
||||
return "12345.6"
|
||||
|
||||
from hub_ohlcv_lib import price_tick_from_market
|
||||
|
||||
tick = price_tick_from_market(_Ex(), "BTC/USDT:USDT")
|
||||
self.assertAlmostEqual(tick, 0.10)
|
||||
|
||||
def test_price_tick_from_info_tick_size(self):
|
||||
class _Ex:
|
||||
markets = {
|
||||
"INJ/USDT:USDT": {
|
||||
"precision": {"price": 4},
|
||||
"info": {"tickSize": "0.001"},
|
||||
"limits": {},
|
||||
}
|
||||
}
|
||||
|
||||
def load_markets(self):
|
||||
return self.markets
|
||||
|
||||
def market(self, sym):
|
||||
return self.markets[sym]
|
||||
|
||||
def price_to_precision(self, sym, price):
|
||||
return "7.123"
|
||||
|
||||
from hub_ohlcv_lib import price_tick_from_market
|
||||
|
||||
tick = price_tick_from_market(_Ex(), "INJ/USDT:USDT")
|
||||
self.assertAlmostEqual(tick, 0.001)
|
||||
|
||||
def test_full_fetch_without_since_paginates_okx_style(self):
|
||||
"""OKX 等无 since 单次约 300 根,须分页至 limit。"""
|
||||
from hub_ohlcv_lib import TIMEFRAME_MS
|
||||
|
||||
step = TIMEFRAME_MS["1h"]
|
||||
want = 1000
|
||||
base = max(0, int(__import__("time").time() * 1000) - want * step)
|
||||
pages = [
|
||||
[[base + i * step, 1.0, 1.1, 0.9, 1.05, 100.0] for i in range(300)],
|
||||
[[base + (300 + i) * step, 2.0, 2.1, 1.9, 2.05, 200.0] for i in range(300)],
|
||||
[[base + (600 + i) * step, 3.0, 3.1, 2.9, 3.05, 300.0] for i in range(300)],
|
||||
[[base + (900 + i) * step, 4.0, 4.1, 3.9, 4.05, 400.0] for i in range(100)],
|
||||
]
|
||||
ex = _FakeExchange(pages)
|
||||
|
||||
out = fetch_ohlcv_for_hub(
|
||||
symbol="ONDO/USDT",
|
||||
timeframe="1h",
|
||||
since_ms=None,
|
||||
limit=want,
|
||||
normalize_symbol_input=lambda s: str(s).strip().upper(),
|
||||
normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s,
|
||||
ensure_markets_loaded=lambda: None,
|
||||
exchange=ex,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertEqual(len(out.get("bars") or []), 1000)
|
||||
self.assertGreaterEqual(len(ex.calls), 4)
|
||||
self.assertAlmostEqual(out["bars"][-1]["close"], 4.05)
|
||||
|
||||
def test_pagination_continues_when_page_smaller_than_chunk(self):
|
||||
"""Gate 等常返回 299 根/次,不应误判为已到末尾。"""
|
||||
base = 1_700_000_000_000
|
||||
step = 4 * 60 * 60 * 1000
|
||||
page1 = [
|
||||
[base + i * step, 1.0, 1.1, 0.9, 1.05, 100.0] for i in range(299)
|
||||
]
|
||||
page2 = [
|
||||
[base + (299 + i) * step, 2.0, 2.1, 1.9, 2.05, 200.0] for i in range(299)
|
||||
]
|
||||
page3 = [
|
||||
[base + (598 + i) * step, 3.0, 3.1, 2.9, 3.05, 300.0] for i in range(50)
|
||||
]
|
||||
ex = _FakeExchange([page1, page2, page3])
|
||||
|
||||
out = fetch_ohlcv_for_hub(
|
||||
symbol="INJ/USDT",
|
||||
timeframe="4h",
|
||||
since_ms=base,
|
||||
limit=600,
|
||||
normalize_symbol_input=lambda s: str(s).strip().upper(),
|
||||
normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s,
|
||||
ensure_markets_loaded=lambda: None,
|
||||
exchange=ex,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertEqual(len(out.get("bars") or []), 600)
|
||||
self.assertGreaterEqual(len(ex.calls), 3)
|
||||
self.assertAlmostEqual(out["bars"][-1]["close"], 3.05)
|
||||
|
||||
def test_pagination_stops_when_next_since_reaches_now(self):
|
||||
"""Gate 等:分页 since 不得越过当前时间,避免 from>to。"""
|
||||
from hub_ohlcv_lib import TIMEFRAME_MS
|
||||
|
||||
step = TIMEFRAME_MS["1d"]
|
||||
now_ms = int(__import__("time").time() * 1000)
|
||||
# 最后一页最后一根 K 的 next_since 将 >= now_ms,应停止不再请求
|
||||
last_open = ((now_ms // step) - 2) * step
|
||||
page = [
|
||||
[last_open - step, 1.0, 1.1, 0.9, 1.0, 10.0],
|
||||
[last_open, 1.1, 1.2, 1.0, 1.1, 11.0],
|
||||
]
|
||||
ex = _FakeExchange([page])
|
||||
|
||||
out = fetch_ohlcv_for_hub(
|
||||
symbol="ONDO/USDT",
|
||||
timeframe="1d",
|
||||
since_ms=last_open - step * 5,
|
||||
limit=10,
|
||||
normalize_symbol_input=lambda s: str(s).strip().upper(),
|
||||
normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s,
|
||||
ensure_markets_loaded=lambda: None,
|
||||
exchange=ex,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertGreaterEqual(len(out.get("bars") or []), 2)
|
||||
self.assertLessEqual(len(ex.calls), 4)
|
||||
|
||||
def test_aggregate_ohlcv_bars_buckets(self):
|
||||
from hub_ohlcv_lib import TIMEFRAME_MS
|
||||
|
||||
h1 = TIMEFRAME_MS["1h"]
|
||||
h4 = TIMEFRAME_MS["4h"]
|
||||
base = (1_700_000_000_000 // h4) * h4
|
||||
src = [
|
||||
{
|
||||
"open_time_ms": base + i * h1,
|
||||
"open": 1.0,
|
||||
"high": 2.0,
|
||||
"low": 0.5,
|
||||
"close": 1.5,
|
||||
"volume": 1.0,
|
||||
}
|
||||
for i in range(4)
|
||||
]
|
||||
out = aggregate_ohlcv_bars(src, "4h")
|
||||
self.assertEqual(len(out), 1)
|
||||
self.assertEqual(out[0]["volume"], 4.0)
|
||||
self.assertEqual(out[0]["high"], 2.0)
|
||||
self.assertEqual(out[0]["low"], 0.5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""hub_ohlcv_lib:分页拉取(Gate 等单次不足 chunk 时仍继续)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from lib.hub.hub_ohlcv_lib import (
|
||||
aggregate_ohlcv_bars,
|
||||
bars_spacing_matches_timeframe,
|
||||
fetch_ohlcv_for_hub,
|
||||
normalize_price_tick,
|
||||
price_tick_from_market,
|
||||
)
|
||||
|
||||
|
||||
class _FakeExchange:
|
||||
def __init__(self, pages, *, timeframes=None):
|
||||
self.pages = list(pages)
|
||||
self.calls = []
|
||||
self.markets = {}
|
||||
self.timeframes = timeframes if timeframes is not None else {}
|
||||
|
||||
def fetch_ohlcv(self, symbol, timeframe=None, since=None, limit=None):
|
||||
self.calls.append(
|
||||
{"symbol": symbol, "since": since, "limit": limit, "timeframe": timeframe}
|
||||
)
|
||||
if not self.pages:
|
||||
return []
|
||||
page = self.pages.pop(0)
|
||||
if since is None:
|
||||
return page
|
||||
return [b for b in page if b[0] >= since]
|
||||
|
||||
|
||||
class TestHubOhlcvLib(unittest.TestCase):
|
||||
def test_normalize_price_tick_snaps_powers_of_ten(self):
|
||||
self.assertAlmostEqual(normalize_price_tick(0.00001), 0.00001)
|
||||
self.assertAlmostEqual(normalize_price_tick(0.001), 0.001)
|
||||
self.assertIsNone(normalize_price_tick(0))
|
||||
|
||||
def test_price_tick_from_decimal_precision(self):
|
||||
class _Ex:
|
||||
markets = {"BTC/USDT:USDT": {"precision": {"price": 2}, "info": {}, "limits": {}}}
|
||||
|
||||
def load_markets(self):
|
||||
return self.markets
|
||||
|
||||
def market(self, sym):
|
||||
return self.markets[sym]
|
||||
|
||||
def price_to_precision(self, sym, price):
|
||||
return "12345.67"
|
||||
|
||||
tick = price_tick_from_market(_Ex(), "BTC/USDT:USDT")
|
||||
self.assertAlmostEqual(tick, 0.01)
|
||||
|
||||
def test_price_tick_from_binance_price_filter(self):
|
||||
class _Ex:
|
||||
markets = {
|
||||
"BTC/USDT:USDT": {
|
||||
"precision": {"price": 2},
|
||||
"info": {
|
||||
"filters": [
|
||||
{"filterType": "PRICE_FILTER", "tickSize": "0.10"},
|
||||
{"filterType": "LOT_SIZE", "stepSize": "0.001"},
|
||||
]
|
||||
},
|
||||
"limits": {},
|
||||
}
|
||||
}
|
||||
|
||||
def load_markets(self):
|
||||
return self.markets
|
||||
|
||||
def market(self, sym):
|
||||
return self.markets[sym]
|
||||
|
||||
def price_to_precision(self, sym, price):
|
||||
return "12345.6"
|
||||
|
||||
from lib.hub.hub_ohlcv_lib import price_tick_from_market
|
||||
|
||||
tick = price_tick_from_market(_Ex(), "BTC/USDT:USDT")
|
||||
self.assertAlmostEqual(tick, 0.10)
|
||||
|
||||
def test_price_tick_from_info_tick_size(self):
|
||||
class _Ex:
|
||||
markets = {
|
||||
"INJ/USDT:USDT": {
|
||||
"precision": {"price": 4},
|
||||
"info": {"tickSize": "0.001"},
|
||||
"limits": {},
|
||||
}
|
||||
}
|
||||
|
||||
def load_markets(self):
|
||||
return self.markets
|
||||
|
||||
def market(self, sym):
|
||||
return self.markets[sym]
|
||||
|
||||
def price_to_precision(self, sym, price):
|
||||
return "7.123"
|
||||
|
||||
from lib.hub.hub_ohlcv_lib import price_tick_from_market
|
||||
|
||||
tick = price_tick_from_market(_Ex(), "INJ/USDT:USDT")
|
||||
self.assertAlmostEqual(tick, 0.001)
|
||||
|
||||
def test_full_fetch_without_since_paginates_okx_style(self):
|
||||
"""OKX 等无 since 单次约 300 根,须分页至 limit。"""
|
||||
from lib.hub.hub_ohlcv_lib import TIMEFRAME_MS
|
||||
|
||||
step = TIMEFRAME_MS["1h"]
|
||||
want = 1000
|
||||
base = max(0, int(__import__("time").time() * 1000) - want * step)
|
||||
pages = [
|
||||
[[base + i * step, 1.0, 1.1, 0.9, 1.05, 100.0] for i in range(300)],
|
||||
[[base + (300 + i) * step, 2.0, 2.1, 1.9, 2.05, 200.0] for i in range(300)],
|
||||
[[base + (600 + i) * step, 3.0, 3.1, 2.9, 3.05, 300.0] for i in range(300)],
|
||||
[[base + (900 + i) * step, 4.0, 4.1, 3.9, 4.05, 400.0] for i in range(100)],
|
||||
]
|
||||
ex = _FakeExchange(pages)
|
||||
|
||||
out = fetch_ohlcv_for_hub(
|
||||
symbol="ONDO/USDT",
|
||||
timeframe="1h",
|
||||
since_ms=None,
|
||||
limit=want,
|
||||
normalize_symbol_input=lambda s: str(s).strip().upper(),
|
||||
normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s,
|
||||
ensure_markets_loaded=lambda: None,
|
||||
exchange=ex,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertEqual(len(out.get("bars") or []), 1000)
|
||||
self.assertGreaterEqual(len(ex.calls), 4)
|
||||
self.assertAlmostEqual(out["bars"][-1]["close"], 4.05)
|
||||
|
||||
def test_pagination_continues_when_page_smaller_than_chunk(self):
|
||||
"""Gate 等常返回 299 根/次,不应误判为已到末尾。"""
|
||||
base = 1_700_000_000_000
|
||||
step = 4 * 60 * 60 * 1000
|
||||
page1 = [
|
||||
[base + i * step, 1.0, 1.1, 0.9, 1.05, 100.0] for i in range(299)
|
||||
]
|
||||
page2 = [
|
||||
[base + (299 + i) * step, 2.0, 2.1, 1.9, 2.05, 200.0] for i in range(299)
|
||||
]
|
||||
page3 = [
|
||||
[base + (598 + i) * step, 3.0, 3.1, 2.9, 3.05, 300.0] for i in range(50)
|
||||
]
|
||||
ex = _FakeExchange([page1, page2, page3])
|
||||
|
||||
out = fetch_ohlcv_for_hub(
|
||||
symbol="INJ/USDT",
|
||||
timeframe="4h",
|
||||
since_ms=base,
|
||||
limit=600,
|
||||
normalize_symbol_input=lambda s: str(s).strip().upper(),
|
||||
normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s,
|
||||
ensure_markets_loaded=lambda: None,
|
||||
exchange=ex,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertEqual(len(out.get("bars") or []), 600)
|
||||
self.assertGreaterEqual(len(ex.calls), 3)
|
||||
self.assertAlmostEqual(out["bars"][-1]["close"], 3.05)
|
||||
|
||||
def test_pagination_stops_when_next_since_reaches_now(self):
|
||||
"""Gate 等:分页 since 不得越过当前时间,避免 from>to。"""
|
||||
from lib.hub.hub_ohlcv_lib import TIMEFRAME_MS
|
||||
|
||||
step = TIMEFRAME_MS["1d"]
|
||||
now_ms = int(__import__("time").time() * 1000)
|
||||
# 最后一页最后一根 K 的 next_since 将 >= now_ms,应停止不再请求
|
||||
last_open = ((now_ms // step) - 2) * step
|
||||
page = [
|
||||
[last_open - step, 1.0, 1.1, 0.9, 1.0, 10.0],
|
||||
[last_open, 1.1, 1.2, 1.0, 1.1, 11.0],
|
||||
]
|
||||
ex = _FakeExchange([page])
|
||||
|
||||
out = fetch_ohlcv_for_hub(
|
||||
symbol="ONDO/USDT",
|
||||
timeframe="1d",
|
||||
since_ms=last_open - step * 5,
|
||||
limit=10,
|
||||
normalize_symbol_input=lambda s: str(s).strip().upper(),
|
||||
normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s,
|
||||
ensure_markets_loaded=lambda: None,
|
||||
exchange=ex,
|
||||
)
|
||||
self.assertTrue(out.get("ok"))
|
||||
self.assertGreaterEqual(len(out.get("bars") or []), 2)
|
||||
self.assertLessEqual(len(ex.calls), 4)
|
||||
|
||||
def test_aggregate_ohlcv_bars_buckets(self):
|
||||
from lib.hub.hub_ohlcv_lib import TIMEFRAME_MS
|
||||
|
||||
h1 = TIMEFRAME_MS["1h"]
|
||||
h4 = TIMEFRAME_MS["4h"]
|
||||
base = (1_700_000_000_000 // h4) * h4
|
||||
src = [
|
||||
{
|
||||
"open_time_ms": base + i * h1,
|
||||
"open": 1.0,
|
||||
"high": 2.0,
|
||||
"low": 0.5,
|
||||
"close": 1.5,
|
||||
"volume": 1.0,
|
||||
}
|
||||
for i in range(4)
|
||||
]
|
||||
out = aggregate_ohlcv_bars(src, "4h")
|
||||
self.assertEqual(len(out), 1)
|
||||
self.assertEqual(out[0]["volume"], 4.0)
|
||||
self.assertEqual(out[0]["high"], 2.0)
|
||||
self.assertEqual(out[0]["low"], 0.5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -5,7 +5,12 @@ import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
try:
|
||||
import pytest
|
||||
except ImportError: # pragma: no cover
|
||||
import unittest
|
||||
|
||||
raise unittest.SkipTest("pytest not installed")
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
@@ -1,348 +1,348 @@
|
||||
"""币种档案库:5m 聚合与视窗计算。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from hub_ohlcv_lib import aggregate_ohlcv_bars
|
||||
from datetime import datetime, timezone
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from hub_symbol_archive_lib import (
|
||||
CHART_DISPLAY_TZ,
|
||||
_compute_period_stats,
|
||||
_fill_missing_bars,
|
||||
init_db,
|
||||
list_daily_trades,
|
||||
load_symbol_trades,
|
||||
ms_to_wall_clock_str,
|
||||
parse_wall_clock_ms,
|
||||
resolve_archive_chart,
|
||||
trading_day_bounds_ms,
|
||||
upsert_bars_5m,
|
||||
upsert_trade_overlay,
|
||||
list_symbol_rows,
|
||||
upsert_trades_cache,
|
||||
)
|
||||
|
||||
|
||||
def _seed_5m_bars(
|
||||
db: Path,
|
||||
start_ms: int,
|
||||
count: int,
|
||||
step: int = 300_000,
|
||||
*,
|
||||
ex: str = "gate",
|
||||
sym: str = "ONDO",
|
||||
) -> None:
|
||||
bars = []
|
||||
price = 1.0
|
||||
for i in range(count):
|
||||
o = start_ms + i * step
|
||||
price += 0.001
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": o,
|
||||
"open": price,
|
||||
"high": price + 0.002,
|
||||
"low": price - 0.001,
|
||||
"close": price + 0.001,
|
||||
"volume": 100 + i,
|
||||
}
|
||||
)
|
||||
upsert_bars_5m(ex, sym, bars, db_path=db)
|
||||
|
||||
|
||||
def test_aggregate_15m_from_5m():
|
||||
start = 1_700_000_000_000
|
||||
bars = []
|
||||
for i in range(6):
|
||||
t = start + i * 300_000
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": t,
|
||||
"open": 1.0,
|
||||
"high": 1.1,
|
||||
"low": 0.9,
|
||||
"close": 1.05,
|
||||
"volume": 10,
|
||||
}
|
||||
)
|
||||
agg = aggregate_ohlcv_bars(bars, "15m")
|
||||
assert len(agg) >= 1
|
||||
assert agg[-1]["close"] == bars[-1]["close"]
|
||||
assert agg[0]["open_time_ms"] <= agg[1]["open_time_ms"]
|
||||
|
||||
|
||||
def test_resolve_archive_chart_15m():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
anchor = 1_700_000_000_000
|
||||
_seed_5m_bars(db, anchor - 50 * 300_000, 120)
|
||||
out = resolve_archive_chart(
|
||||
"gate",
|
||||
"ONDO",
|
||||
"15m",
|
||||
anchor_ms=anchor,
|
||||
mode="hold",
|
||||
bars=40,
|
||||
db_path=db,
|
||||
)
|
||||
assert out["ok"] is True
|
||||
assert out["timeframe"] == "15m"
|
||||
assert len(out["candles"]) >= 10
|
||||
|
||||
|
||||
def test_fill_missing_bars_continuity():
|
||||
period = 300_000
|
||||
start = (1_700_000_000_000 // period) * period
|
||||
bars = [
|
||||
{
|
||||
"open_time_ms": start,
|
||||
"open": 1.0,
|
||||
"high": 1.1,
|
||||
"low": 0.9,
|
||||
"close": 1.05,
|
||||
"volume": 10,
|
||||
},
|
||||
{
|
||||
"open_time_ms": start + period * 2,
|
||||
"open": 1.05,
|
||||
"high": 1.15,
|
||||
"low": 1.0,
|
||||
"close": 1.1,
|
||||
"volume": 8,
|
||||
},
|
||||
]
|
||||
filled = _fill_missing_bars(bars, period, start, start + period * 2)
|
||||
assert len(filled) >= 3
|
||||
assert any(b.get("filled") for b in filled)
|
||||
|
||||
|
||||
def test_resolve_archive_chart_history_range():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
open_ms = 1_700_000_000_000
|
||||
close_ms = open_ms + 6 * 3600_000
|
||||
_seed_5m_bars(db, open_ms - 20 * 300_000, 200, ex="gate", sym="BNB/USDT")
|
||||
out = resolve_archive_chart(
|
||||
"gate",
|
||||
"BNB/USDT",
|
||||
"15m",
|
||||
opened_ms=open_ms,
|
||||
closed_ms=close_ms,
|
||||
mode="hold",
|
||||
range_mode="history",
|
||||
db_path=db,
|
||||
)
|
||||
assert out["ok"] is True
|
||||
assert out.get("range_mode") == "history"
|
||||
assert out.get("window_end_ms") <= close_ms + 4 * 3600_000
|
||||
assert len(out["candles"]) >= 40
|
||||
|
||||
|
||||
def test_sync_prunes_missing_trades():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
upsert_trades_cache(
|
||||
"gate",
|
||||
[
|
||||
{"id": 1, "symbol": "BNB/USDT", "result": "止损", "pnl_amount": -1},
|
||||
{"id": 2, "symbol": "BNB/USDT", "result": "止盈", "pnl_amount": 1},
|
||||
],
|
||||
db_path=db,
|
||||
prune_missing=False,
|
||||
)
|
||||
stats = upsert_trades_cache(
|
||||
"gate",
|
||||
[{"id": 1, "symbol": "BNB/USDT", "result": "止损", "pnl_amount": -1}],
|
||||
db_path=db,
|
||||
prune_missing=True,
|
||||
)
|
||||
rows = load_symbol_trades("gate", "BNB/USDT", db_path=db)
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["trade_id"] == 1
|
||||
assert stats["removed"] == 1
|
||||
|
||||
|
||||
def test_list_with_overlay_filters():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
upsert_trades_cache(
|
||||
"gate",
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"symbol": "ONDO",
|
||||
"direction": "long",
|
||||
"result": "止盈",
|
||||
"pnl_amount": 12.5,
|
||||
"opened_at": "2026-01-01 10:00:00",
|
||||
"closed_at": "2026-01-01 12:00:00",
|
||||
"opened_at_ms": 1_700_000_000_000,
|
||||
"closed_at_ms": 1_700_007_200_000,
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"symbol": "ONDO",
|
||||
"direction": "short",
|
||||
"result": "止损",
|
||||
"pnl_amount": -3.2,
|
||||
"opened_at": "2026-01-02 10:00:00",
|
||||
"closed_at": "2026-01-02 11:00:00",
|
||||
"opened_at_ms": 1_700_086_400_000,
|
||||
"closed_at_ms": 1_700_090_000_000,
|
||||
},
|
||||
],
|
||||
db_path=db,
|
||||
)
|
||||
upsert_trade_overlay("gate", 2, behavior_tag="sick", note="追高", db_path=db)
|
||||
rows = list_symbol_rows(db_path=db)
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["trade_count"] == 2
|
||||
sick_only = list_symbol_rows(filter_sick=True, db_path=db)
|
||||
assert len(sick_only) == 1
|
||||
profit_only = list_symbol_rows(filter_profit=True, db_path=db)
|
||||
assert len(profit_only) == 1
|
||||
|
||||
|
||||
def test_parse_wall_clock_ms_uses_utc_plus_8():
|
||||
ms = parse_wall_clock_ms("2026-06-07 20:30:00")
|
||||
assert ms is not None
|
||||
dt_utc = datetime.fromtimestamp(ms / 1000.0, tz=timezone.utc)
|
||||
dt_bj = dt_utc.astimezone(CHART_DISPLAY_TZ)
|
||||
assert dt_bj.strftime("%Y-%m-%d %H:%M:%S") == "2026-06-07 20:30:00"
|
||||
assert ms_to_wall_clock_str(ms) == "2026-06-07 20:30:00"
|
||||
assert parse_wall_clock_ms("2026-06-07 20:30") == ms
|
||||
|
||||
|
||||
def test_parse_wall_clock_ms_accepts_epoch_strings():
|
||||
ms = 1_700_000_000_000
|
||||
assert parse_wall_clock_ms(str(ms)) == ms
|
||||
assert parse_wall_clock_ms(str(ms // 1000)) == ms
|
||||
|
||||
|
||||
def test_resolve_archive_chart_history_uses_trade_span_not_200_bars():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
opened = 1_700_000_000_000
|
||||
closed = opened + 20 * 24 * 3600_000
|
||||
_seed_5m_bars(db, opened - 35 * 24 * 3600_000, 40 * 24 * 12)
|
||||
out = resolve_archive_chart(
|
||||
"gate",
|
||||
"ONDO",
|
||||
"15m",
|
||||
opened_ms=opened,
|
||||
closed_ms=closed,
|
||||
mode="hold",
|
||||
bars=200,
|
||||
range_mode="history",
|
||||
db_path=db,
|
||||
)
|
||||
assert out["ok"] is True
|
||||
assert out["range_mode"] == "history"
|
||||
assert out["bar_count"] > 200
|
||||
|
||||
|
||||
def test_upsert_forces_sync_exchange_key():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
upsert_trades_cache(
|
||||
"gate_bot",
|
||||
[
|
||||
{
|
||||
"id": 77,
|
||||
"exchange_key": "gate",
|
||||
"account_exchange_key": "gate",
|
||||
"symbol": "ETH/USDT",
|
||||
"result": "止损",
|
||||
"pnl_amount": -1,
|
||||
"opened_at_ms": 1_700_000_000_000,
|
||||
"closed_at_ms": 1_700_007_200_000,
|
||||
}
|
||||
],
|
||||
db_path=db,
|
||||
)
|
||||
rows = load_symbol_trades("gate_bot", "ETH/USDT", db_path=db)
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["exchange_key"] == "gate_bot"
|
||||
assert "account_exchange_key" not in rows[0]
|
||||
|
||||
|
||||
def test_compute_period_stats_win_loss_metrics():
|
||||
rows = [
|
||||
{"exchange_key": "binance", "pnl_amount": 10.0, "behavior_tag": ""},
|
||||
{"exchange_key": "binance", "pnl_amount": 4.0, "behavior_tag": ""},
|
||||
{"exchange_key": "okx", "pnl_amount": -3.0, "behavior_tag": "sick"},
|
||||
{"exchange_key": "okx", "pnl_amount": -6.0, "behavior_tag": ""},
|
||||
]
|
||||
st = _compute_period_stats(rows)
|
||||
assert st["open_count"] == 4
|
||||
assert st["win_count"] == 2
|
||||
assert st["loss_count"] == 2
|
||||
assert st["avg_win"] == 7.0
|
||||
assert st["avg_loss"] == -4.5
|
||||
assert st["max_win"] == 10.0
|
||||
assert st["max_loss"] == -6.0
|
||||
assert st["win_rate"] == 50.0
|
||||
assert st["profit_loss_ratio"] == round(7.0 / 4.5, 2)
|
||||
assert st["sick_count"] == 1
|
||||
assert st["pnl_total"] == 5.0
|
||||
assert st["pnl_ex_sick"] == 8.0
|
||||
assert st["by_exchange"]["binance"]["win_count"] == 2
|
||||
assert st["by_exchange"]["binance"]["win_rate"] == 100.0
|
||||
assert st["by_exchange"]["binance"]["profit_loss_ratio"] is None
|
||||
|
||||
|
||||
def test_list_daily_trades_search_filters_stats():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
day = "2023-11-15"
|
||||
start_ms, _ = trading_day_bounds_ms(day)
|
||||
btc_close = start_ms + 3_600_000
|
||||
eth_close = start_ms + 7_200_000
|
||||
upsert_trades_cache(
|
||||
"gate",
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"symbol": "BTC/USDT",
|
||||
"result": "止盈",
|
||||
"pnl_amount": 5.0,
|
||||
"opened_at_ms": start_ms,
|
||||
"closed_at_ms": btc_close,
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"symbol": "ETH/USDT",
|
||||
"result": "止损",
|
||||
"pnl_amount": -2.0,
|
||||
"opened_at_ms": btc_close,
|
||||
"closed_at_ms": eth_close,
|
||||
},
|
||||
],
|
||||
db_path=db,
|
||||
)
|
||||
payload = list_daily_trades(
|
||||
period="range",
|
||||
date_from=day,
|
||||
date_to=day,
|
||||
search="btc",
|
||||
db_path=db,
|
||||
)
|
||||
assert len(payload["trades"]) == 1
|
||||
assert payload["trades"][0]["symbol"] == "BTC/USDT"
|
||||
st = payload["stats"]
|
||||
assert st["open_count"] == 1
|
||||
assert st["win_count"] == 1
|
||||
assert st["loss_count"] == 0
|
||||
assert st["max_win"] == 5.0
|
||||
assert st["pnl_total"] == 5.0
|
||||
"""币种档案库:5m 聚合与视窗计算。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from lib.hub.hub_ohlcv_lib import aggregate_ohlcv_bars
|
||||
from datetime import datetime, timezone
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from lib.hub.hub_symbol_archive_lib import (
|
||||
CHART_DISPLAY_TZ,
|
||||
_compute_period_stats,
|
||||
_fill_missing_bars,
|
||||
init_db,
|
||||
list_daily_trades,
|
||||
load_symbol_trades,
|
||||
ms_to_wall_clock_str,
|
||||
parse_wall_clock_ms,
|
||||
resolve_archive_chart,
|
||||
trading_day_bounds_ms,
|
||||
upsert_bars_5m,
|
||||
upsert_trade_overlay,
|
||||
list_symbol_rows,
|
||||
upsert_trades_cache,
|
||||
)
|
||||
|
||||
|
||||
def _seed_5m_bars(
|
||||
db: Path,
|
||||
start_ms: int,
|
||||
count: int,
|
||||
step: int = 300_000,
|
||||
*,
|
||||
ex: str = "gate",
|
||||
sym: str = "ONDO",
|
||||
) -> None:
|
||||
bars = []
|
||||
price = 1.0
|
||||
for i in range(count):
|
||||
o = start_ms + i * step
|
||||
price += 0.001
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": o,
|
||||
"open": price,
|
||||
"high": price + 0.002,
|
||||
"low": price - 0.001,
|
||||
"close": price + 0.001,
|
||||
"volume": 100 + i,
|
||||
}
|
||||
)
|
||||
upsert_bars_5m(ex, sym, bars, db_path=db)
|
||||
|
||||
|
||||
def test_aggregate_15m_from_5m():
|
||||
start = 1_700_000_000_000
|
||||
bars = []
|
||||
for i in range(6):
|
||||
t = start + i * 300_000
|
||||
bars.append(
|
||||
{
|
||||
"open_time_ms": t,
|
||||
"open": 1.0,
|
||||
"high": 1.1,
|
||||
"low": 0.9,
|
||||
"close": 1.05,
|
||||
"volume": 10,
|
||||
}
|
||||
)
|
||||
agg = aggregate_ohlcv_bars(bars, "15m")
|
||||
assert len(agg) >= 1
|
||||
assert agg[-1]["close"] == bars[-1]["close"]
|
||||
assert agg[0]["open_time_ms"] <= agg[1]["open_time_ms"]
|
||||
|
||||
|
||||
def test_resolve_archive_chart_15m():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
anchor = 1_700_000_000_000
|
||||
_seed_5m_bars(db, anchor - 50 * 300_000, 120)
|
||||
out = resolve_archive_chart(
|
||||
"gate",
|
||||
"ONDO",
|
||||
"15m",
|
||||
anchor_ms=anchor,
|
||||
mode="hold",
|
||||
bars=40,
|
||||
db_path=db,
|
||||
)
|
||||
assert out["ok"] is True
|
||||
assert out["timeframe"] == "15m"
|
||||
assert len(out["candles"]) >= 10
|
||||
|
||||
|
||||
def test_fill_missing_bars_continuity():
|
||||
period = 300_000
|
||||
start = (1_700_000_000_000 // period) * period
|
||||
bars = [
|
||||
{
|
||||
"open_time_ms": start,
|
||||
"open": 1.0,
|
||||
"high": 1.1,
|
||||
"low": 0.9,
|
||||
"close": 1.05,
|
||||
"volume": 10,
|
||||
},
|
||||
{
|
||||
"open_time_ms": start + period * 2,
|
||||
"open": 1.05,
|
||||
"high": 1.15,
|
||||
"low": 1.0,
|
||||
"close": 1.1,
|
||||
"volume": 8,
|
||||
},
|
||||
]
|
||||
filled = _fill_missing_bars(bars, period, start, start + period * 2)
|
||||
assert len(filled) >= 3
|
||||
assert any(b.get("filled") for b in filled)
|
||||
|
||||
|
||||
def test_resolve_archive_chart_history_range():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
open_ms = 1_700_000_000_000
|
||||
close_ms = open_ms + 6 * 3600_000
|
||||
_seed_5m_bars(db, open_ms - 20 * 300_000, 200, ex="gate", sym="BNB/USDT")
|
||||
out = resolve_archive_chart(
|
||||
"gate",
|
||||
"BNB/USDT",
|
||||
"15m",
|
||||
opened_ms=open_ms,
|
||||
closed_ms=close_ms,
|
||||
mode="hold",
|
||||
range_mode="history",
|
||||
db_path=db,
|
||||
)
|
||||
assert out["ok"] is True
|
||||
assert out.get("range_mode") == "history"
|
||||
assert out.get("window_end_ms") <= close_ms + 4 * 3600_000
|
||||
assert len(out["candles"]) >= 40
|
||||
|
||||
|
||||
def test_sync_prunes_missing_trades():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
upsert_trades_cache(
|
||||
"gate",
|
||||
[
|
||||
{"id": 1, "symbol": "BNB/USDT", "result": "止损", "pnl_amount": -1},
|
||||
{"id": 2, "symbol": "BNB/USDT", "result": "止盈", "pnl_amount": 1},
|
||||
],
|
||||
db_path=db,
|
||||
prune_missing=False,
|
||||
)
|
||||
stats = upsert_trades_cache(
|
||||
"gate",
|
||||
[{"id": 1, "symbol": "BNB/USDT", "result": "止损", "pnl_amount": -1}],
|
||||
db_path=db,
|
||||
prune_missing=True,
|
||||
)
|
||||
rows = load_symbol_trades("gate", "BNB/USDT", db_path=db)
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["trade_id"] == 1
|
||||
assert stats["removed"] == 1
|
||||
|
||||
|
||||
def test_list_with_overlay_filters():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
upsert_trades_cache(
|
||||
"gate",
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"symbol": "ONDO",
|
||||
"direction": "long",
|
||||
"result": "止盈",
|
||||
"pnl_amount": 12.5,
|
||||
"opened_at": "2026-01-01 10:00:00",
|
||||
"closed_at": "2026-01-01 12:00:00",
|
||||
"opened_at_ms": 1_700_000_000_000,
|
||||
"closed_at_ms": 1_700_007_200_000,
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"symbol": "ONDO",
|
||||
"direction": "short",
|
||||
"result": "止损",
|
||||
"pnl_amount": -3.2,
|
||||
"opened_at": "2026-01-02 10:00:00",
|
||||
"closed_at": "2026-01-02 11:00:00",
|
||||
"opened_at_ms": 1_700_086_400_000,
|
||||
"closed_at_ms": 1_700_090_000_000,
|
||||
},
|
||||
],
|
||||
db_path=db,
|
||||
)
|
||||
upsert_trade_overlay("gate", 2, behavior_tag="sick", note="追高", db_path=db)
|
||||
rows = list_symbol_rows(db_path=db)
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["trade_count"] == 2
|
||||
sick_only = list_symbol_rows(filter_sick=True, db_path=db)
|
||||
assert len(sick_only) == 1
|
||||
profit_only = list_symbol_rows(filter_profit=True, db_path=db)
|
||||
assert len(profit_only) == 1
|
||||
|
||||
|
||||
def test_parse_wall_clock_ms_uses_utc_plus_8():
|
||||
ms = parse_wall_clock_ms("2026-06-07 20:30:00")
|
||||
assert ms is not None
|
||||
dt_utc = datetime.fromtimestamp(ms / 1000.0, tz=timezone.utc)
|
||||
dt_bj = dt_utc.astimezone(CHART_DISPLAY_TZ)
|
||||
assert dt_bj.strftime("%Y-%m-%d %H:%M:%S") == "2026-06-07 20:30:00"
|
||||
assert ms_to_wall_clock_str(ms) == "2026-06-07 20:30:00"
|
||||
assert parse_wall_clock_ms("2026-06-07 20:30") == ms
|
||||
|
||||
|
||||
def test_parse_wall_clock_ms_accepts_epoch_strings():
|
||||
ms = 1_700_000_000_000
|
||||
assert parse_wall_clock_ms(str(ms)) == ms
|
||||
assert parse_wall_clock_ms(str(ms // 1000)) == ms
|
||||
|
||||
|
||||
def test_resolve_archive_chart_history_uses_trade_span_not_200_bars():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
opened = 1_700_000_000_000
|
||||
closed = opened + 20 * 24 * 3600_000
|
||||
_seed_5m_bars(db, opened - 35 * 24 * 3600_000, 40 * 24 * 12)
|
||||
out = resolve_archive_chart(
|
||||
"gate",
|
||||
"ONDO",
|
||||
"15m",
|
||||
opened_ms=opened,
|
||||
closed_ms=closed,
|
||||
mode="hold",
|
||||
bars=200,
|
||||
range_mode="history",
|
||||
db_path=db,
|
||||
)
|
||||
assert out["ok"] is True
|
||||
assert out["range_mode"] == "history"
|
||||
assert out["bar_count"] > 200
|
||||
|
||||
|
||||
def test_upsert_forces_sync_exchange_key():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
upsert_trades_cache(
|
||||
"gate_bot",
|
||||
[
|
||||
{
|
||||
"id": 77,
|
||||
"exchange_key": "gate",
|
||||
"account_exchange_key": "gate",
|
||||
"symbol": "ETH/USDT",
|
||||
"result": "止损",
|
||||
"pnl_amount": -1,
|
||||
"opened_at_ms": 1_700_000_000_000,
|
||||
"closed_at_ms": 1_700_007_200_000,
|
||||
}
|
||||
],
|
||||
db_path=db,
|
||||
)
|
||||
rows = load_symbol_trades("gate_bot", "ETH/USDT", db_path=db)
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["exchange_key"] == "gate_bot"
|
||||
assert "account_exchange_key" not in rows[0]
|
||||
|
||||
|
||||
def test_compute_period_stats_win_loss_metrics():
|
||||
rows = [
|
||||
{"exchange_key": "binance", "pnl_amount": 10.0, "behavior_tag": ""},
|
||||
{"exchange_key": "binance", "pnl_amount": 4.0, "behavior_tag": ""},
|
||||
{"exchange_key": "okx", "pnl_amount": -3.0, "behavior_tag": "sick"},
|
||||
{"exchange_key": "okx", "pnl_amount": -6.0, "behavior_tag": ""},
|
||||
]
|
||||
st = _compute_period_stats(rows)
|
||||
assert st["open_count"] == 4
|
||||
assert st["win_count"] == 2
|
||||
assert st["loss_count"] == 2
|
||||
assert st["avg_win"] == 7.0
|
||||
assert st["avg_loss"] == -4.5
|
||||
assert st["max_win"] == 10.0
|
||||
assert st["max_loss"] == -6.0
|
||||
assert st["win_rate"] == 50.0
|
||||
assert st["profit_loss_ratio"] == round(7.0 / 4.5, 2)
|
||||
assert st["sick_count"] == 1
|
||||
assert st["pnl_total"] == 5.0
|
||||
assert st["pnl_ex_sick"] == 8.0
|
||||
assert st["by_exchange"]["binance"]["win_count"] == 2
|
||||
assert st["by_exchange"]["binance"]["win_rate"] == 100.0
|
||||
assert st["by_exchange"]["binance"]["profit_loss_ratio"] is None
|
||||
|
||||
|
||||
def test_list_daily_trades_search_filters_stats():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
day = "2023-11-15"
|
||||
start_ms, _ = trading_day_bounds_ms(day)
|
||||
btc_close = start_ms + 3_600_000
|
||||
eth_close = start_ms + 7_200_000
|
||||
upsert_trades_cache(
|
||||
"gate",
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"symbol": "BTC/USDT",
|
||||
"result": "止盈",
|
||||
"pnl_amount": 5.0,
|
||||
"opened_at_ms": start_ms,
|
||||
"closed_at_ms": btc_close,
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"symbol": "ETH/USDT",
|
||||
"result": "止损",
|
||||
"pnl_amount": -2.0,
|
||||
"opened_at_ms": btc_close,
|
||||
"closed_at_ms": eth_close,
|
||||
},
|
||||
],
|
||||
db_path=db,
|
||||
)
|
||||
payload = list_daily_trades(
|
||||
period="range",
|
||||
date_from=day,
|
||||
date_to=day,
|
||||
search="btc",
|
||||
db_path=db,
|
||||
)
|
||||
assert len(payload["trades"]) == 1
|
||||
assert payload["trades"][0]["symbol"] == "BTC/USDT"
|
||||
st = payload["stats"]
|
||||
assert st["open_count"] == 1
|
||||
assert st["win_count"] == 1
|
||||
assert st["loss_count"] == 0
|
||||
assert st["max_win"] == 5.0
|
||||
assert st["pnl_total"] == 5.0
|
||||
|
||||
@@ -1,102 +1,102 @@
|
||||
"""档案交易:strategy_trade_snapshots 补全 gate_bot 漏记。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from hub_trades_lib import fetch_trades_for_archive
|
||||
|
||||
|
||||
def _init_db(path: Path) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(str(path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE trade_records (
|
||||
id INTEGER PRIMARY KEY,
|
||||
symbol TEXT,
|
||||
direction TEXT,
|
||||
result TEXT,
|
||||
pnl_amount REAL,
|
||||
opened_at TEXT,
|
||||
closed_at TEXT,
|
||||
opened_at_ms INTEGER,
|
||||
closed_at_ms INTEGER,
|
||||
created_at TEXT,
|
||||
trend_plan_id INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE strategy_trade_snapshots (
|
||||
id INTEGER PRIMARY KEY,
|
||||
strategy_type TEXT,
|
||||
source_id INTEGER,
|
||||
symbol TEXT,
|
||||
direction TEXT,
|
||||
result_label TEXT,
|
||||
status_at_close TEXT,
|
||||
opened_at TEXT,
|
||||
closed_at TEXT,
|
||||
pnl_amount REAL,
|
||||
snapshot_json TEXT,
|
||||
created_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
return conn
|
||||
|
||||
|
||||
def test_merge_snapshot_when_trade_record_missing():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
conn = _init_db(Path(td) / "t.db")
|
||||
closed = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO strategy_trade_snapshots (
|
||||
id, strategy_type, source_id, symbol, direction,
|
||||
result_label, opened_at, closed_at, pnl_amount, snapshot_json, created_at
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
|
||||
""",
|
||||
(7, "trend_pullback", 42, "ONDO/USDT", "long", "止损", closed, closed, -1.2, "{}", closed),
|
||||
)
|
||||
conn.commit()
|
||||
trades = fetch_trades_for_archive(conn, days=30, limit=50)
|
||||
conn.close()
|
||||
assert len(trades) == 1
|
||||
assert trades[0]["symbol"] == "ONDO/USDT"
|
||||
assert trades[0]["id"] == -7
|
||||
assert trades[0].get("from_snapshot") is True
|
||||
|
||||
|
||||
def test_skip_snapshot_when_trade_record_exists():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
conn = _init_db(Path(td) / "t.db")
|
||||
closed = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO trade_records (
|
||||
id, symbol, direction, result, pnl_amount,
|
||||
opened_at, closed_at, opened_at_ms, closed_at_ms, created_at, trend_plan_id
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
|
||||
""",
|
||||
(1, "ONDO/USDT", "long", "止损", -1.2, closed, closed, 1, 2, closed, 42),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO strategy_trade_snapshots (
|
||||
id, strategy_type, source_id, symbol, direction,
|
||||
result_label, opened_at, closed_at, pnl_amount, snapshot_json, created_at
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
|
||||
""",
|
||||
(7, "trend_pullback", 42, "ONDO/USDT", "long", "止损", closed, closed, -1.2, "{}", closed),
|
||||
)
|
||||
conn.commit()
|
||||
trades = fetch_trades_for_archive(conn, days=30, limit=50)
|
||||
conn.close()
|
||||
assert len(trades) == 1
|
||||
assert trades[0]["id"] == 1
|
||||
"""档案交易:strategy_trade_snapshots 补全 gate_bot 漏记。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from lib.hub.hub_trades_lib import fetch_trades_for_archive
|
||||
|
||||
|
||||
def _init_db(path: Path) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(str(path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE trade_records (
|
||||
id INTEGER PRIMARY KEY,
|
||||
symbol TEXT,
|
||||
direction TEXT,
|
||||
result TEXT,
|
||||
pnl_amount REAL,
|
||||
opened_at TEXT,
|
||||
closed_at TEXT,
|
||||
opened_at_ms INTEGER,
|
||||
closed_at_ms INTEGER,
|
||||
created_at TEXT,
|
||||
trend_plan_id INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE strategy_trade_snapshots (
|
||||
id INTEGER PRIMARY KEY,
|
||||
strategy_type TEXT,
|
||||
source_id INTEGER,
|
||||
symbol TEXT,
|
||||
direction TEXT,
|
||||
result_label TEXT,
|
||||
status_at_close TEXT,
|
||||
opened_at TEXT,
|
||||
closed_at TEXT,
|
||||
pnl_amount REAL,
|
||||
snapshot_json TEXT,
|
||||
created_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
return conn
|
||||
|
||||
|
||||
def test_merge_snapshot_when_trade_record_missing():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
conn = _init_db(Path(td) / "t.db")
|
||||
closed = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO strategy_trade_snapshots (
|
||||
id, strategy_type, source_id, symbol, direction,
|
||||
result_label, opened_at, closed_at, pnl_amount, snapshot_json, created_at
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
|
||||
""",
|
||||
(7, "trend_pullback", 42, "ONDO/USDT", "long", "止损", closed, closed, -1.2, "{}", closed),
|
||||
)
|
||||
conn.commit()
|
||||
trades = fetch_trades_for_archive(conn, days=30, limit=50)
|
||||
conn.close()
|
||||
assert len(trades) == 1
|
||||
assert trades[0]["symbol"] == "ONDO/USDT"
|
||||
assert trades[0]["id"] == -7
|
||||
assert trades[0].get("from_snapshot") is True
|
||||
|
||||
|
||||
def test_skip_snapshot_when_trade_record_exists():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
conn = _init_db(Path(td) / "t.db")
|
||||
closed = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO trade_records (
|
||||
id, symbol, direction, result, pnl_amount,
|
||||
opened_at, closed_at, opened_at_ms, closed_at_ms, created_at, trend_plan_id
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
|
||||
""",
|
||||
(1, "ONDO/USDT", "long", "止损", -1.2, closed, closed, 1, 2, closed, 42),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO strategy_trade_snapshots (
|
||||
id, strategy_type, source_id, symbol, direction,
|
||||
result_label, opened_at, closed_at, pnl_amount, snapshot_json, created_at
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
|
||||
""",
|
||||
(7, "trend_pullback", 42, "ONDO/USDT", "long", "止损", closed, closed, -1.2, "{}", closed),
|
||||
)
|
||||
conn.commit()
|
||||
trades = fetch_trades_for_archive(conn, days=30, limit=50)
|
||||
conn.close()
|
||||
assert len(trades) == 1
|
||||
assert trades[0]["id"] == 1
|
||||
|
||||
+198
-198
@@ -1,198 +1,198 @@
|
||||
"""hub_trades_lib 单元测试。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
|
||||
from hub_trades_lib import (
|
||||
fetch_trades_for_trading_day,
|
||||
summarize_trades,
|
||||
trading_day_from_dt,
|
||||
trading_day_window_bounds,
|
||||
)
|
||||
|
||||
|
||||
class HubTradesLibTest(unittest.TestCase):
|
||||
def test_trading_day_reset(self):
|
||||
dt = datetime(2026, 6, 6, 7, 30, 0)
|
||||
self.assertEqual(trading_day_from_dt(dt, 8), "2026-06-05")
|
||||
dt2 = datetime(2026, 6, 6, 8, 0, 0)
|
||||
self.assertEqual(trading_day_from_dt(dt2, 8), "2026-06-06")
|
||||
|
||||
def test_trading_day_window_bounds(self):
|
||||
start, end = trading_day_window_bounds("2026-06-06", 8)
|
||||
self.assertEqual(start, "2026-06-06 08:00:00")
|
||||
self.assertEqual(end, "2026-06-07 07:59:59")
|
||||
|
||||
def test_fetch_and_summarize(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""CREATE TABLE trade_records (
|
||||
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
|
||||
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
|
||||
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
|
||||
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
|
||||
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
|
||||
(
|
||||
"ONDO/USDT",
|
||||
"short",
|
||||
"止损",
|
||||
None,
|
||||
-0.5,
|
||||
None,
|
||||
None,
|
||||
"2026-06-06 10:00:00",
|
||||
None,
|
||||
"2026-06-06 09:00:00",
|
||||
None,
|
||||
"2026-06-06 10:00:00",
|
||||
"趋势回调",
|
||||
None,
|
||||
None,
|
||||
"trend",
|
||||
"",
|
||||
None,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
rows = fetch_trades_for_trading_day(conn, "2026-06-06")
|
||||
self.assertEqual(len(rows), 1)
|
||||
stats = summarize_trades(rows)
|
||||
self.assertEqual(stats["closed_count"], 1)
|
||||
self.assertEqual(stats["loss_count"], 1)
|
||||
self.assertAlmostEqual(stats["total_pnl_u"], -0.5)
|
||||
conn.close()
|
||||
|
||||
def test_early_morning_belongs_prev_trading_day(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""CREATE TABLE trade_records (
|
||||
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
|
||||
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
|
||||
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
|
||||
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
|
||||
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
|
||||
(
|
||||
"BTC/USDT",
|
||||
"long",
|
||||
"止盈",
|
||||
None,
|
||||
1.2,
|
||||
None,
|
||||
None,
|
||||
"2026-06-07 07:30:00",
|
||||
None,
|
||||
"2026-06-07 06:00:00",
|
||||
None,
|
||||
"2026-06-07 07:30:00",
|
||||
"关键位",
|
||||
None,
|
||||
None,
|
||||
"trend",
|
||||
"",
|
||||
None,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
self.assertEqual(len(fetch_trades_for_trading_day(conn, "2026-06-07")), 0)
|
||||
self.assertEqual(len(fetch_trades_for_trading_day(conn, "2026-06-06")), 1)
|
||||
conn.close()
|
||||
|
||||
def test_reviewed_fields_preferred(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""CREATE TABLE trade_records (
|
||||
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
|
||||
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
|
||||
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
|
||||
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
|
||||
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
|
||||
(
|
||||
"ETH/USDT",
|
||||
"long",
|
||||
"止损",
|
||||
"止盈",
|
||||
-0.5,
|
||||
2.0,
|
||||
None,
|
||||
"2026-06-06 09:00:00",
|
||||
"2026-06-06 11:00:00",
|
||||
"2026-06-06 08:00:00",
|
||||
None,
|
||||
"2026-06-06 11:00:00",
|
||||
"趋势回调",
|
||||
None,
|
||||
None,
|
||||
"trend",
|
||||
"",
|
||||
"2026-06-06 12:00:00",
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
rows = fetch_trades_for_trading_day(conn, "2026-06-06")
|
||||
self.assertEqual(len(rows), 1)
|
||||
self.assertEqual(rows[0]["result"], "止盈")
|
||||
self.assertAlmostEqual(rows[0]["pnl_amount"], 2.0)
|
||||
self.assertTrue(rows[0]["reviewed"])
|
||||
conn.close()
|
||||
|
||||
def test_time_close_result_included(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""CREATE TABLE trade_records (
|
||||
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
|
||||
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
|
||||
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
|
||||
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
|
||||
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
|
||||
(
|
||||
"BTC/USDT",
|
||||
"long",
|
||||
"时间平仓",
|
||||
None,
|
||||
1.2,
|
||||
None,
|
||||
None,
|
||||
"2026-06-06 12:00:00",
|
||||
None,
|
||||
"2026-06-06 08:00:00",
|
||||
None,
|
||||
"2026-06-06 12:00:00",
|
||||
"趋势回调",
|
||||
None,
|
||||
None,
|
||||
"trend",
|
||||
"",
|
||||
None,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
rows = fetch_trades_for_trading_day(conn, "2026-06-06")
|
||||
self.assertEqual(len(rows), 1)
|
||||
self.assertEqual(rows[0]["result"], "时间平仓")
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""hub_trades_lib 单元测试。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
|
||||
from lib.hub.hub_trades_lib import (
|
||||
fetch_trades_for_trading_day,
|
||||
summarize_trades,
|
||||
trading_day_from_dt,
|
||||
trading_day_window_bounds,
|
||||
)
|
||||
|
||||
|
||||
class HubTradesLibTest(unittest.TestCase):
|
||||
def test_trading_day_reset(self):
|
||||
dt = datetime(2026, 6, 6, 7, 30, 0)
|
||||
self.assertEqual(trading_day_from_dt(dt, 8), "2026-06-05")
|
||||
dt2 = datetime(2026, 6, 6, 8, 0, 0)
|
||||
self.assertEqual(trading_day_from_dt(dt2, 8), "2026-06-06")
|
||||
|
||||
def test_trading_day_window_bounds(self):
|
||||
start, end = trading_day_window_bounds("2026-06-06", 8)
|
||||
self.assertEqual(start, "2026-06-06 08:00:00")
|
||||
self.assertEqual(end, "2026-06-07 07:59:59")
|
||||
|
||||
def test_fetch_and_summarize(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""CREATE TABLE trade_records (
|
||||
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
|
||||
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
|
||||
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
|
||||
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
|
||||
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
|
||||
(
|
||||
"ONDO/USDT",
|
||||
"short",
|
||||
"止损",
|
||||
None,
|
||||
-0.5,
|
||||
None,
|
||||
None,
|
||||
"2026-06-06 10:00:00",
|
||||
None,
|
||||
"2026-06-06 09:00:00",
|
||||
None,
|
||||
"2026-06-06 10:00:00",
|
||||
"趋势回调",
|
||||
None,
|
||||
None,
|
||||
"trend",
|
||||
"",
|
||||
None,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
rows = fetch_trades_for_trading_day(conn, "2026-06-06")
|
||||
self.assertEqual(len(rows), 1)
|
||||
stats = summarize_trades(rows)
|
||||
self.assertEqual(stats["closed_count"], 1)
|
||||
self.assertEqual(stats["loss_count"], 1)
|
||||
self.assertAlmostEqual(stats["total_pnl_u"], -0.5)
|
||||
conn.close()
|
||||
|
||||
def test_early_morning_belongs_prev_trading_day(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""CREATE TABLE trade_records (
|
||||
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
|
||||
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
|
||||
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
|
||||
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
|
||||
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
|
||||
(
|
||||
"BTC/USDT",
|
||||
"long",
|
||||
"止盈",
|
||||
None,
|
||||
1.2,
|
||||
None,
|
||||
None,
|
||||
"2026-06-07 07:30:00",
|
||||
None,
|
||||
"2026-06-07 06:00:00",
|
||||
None,
|
||||
"2026-06-07 07:30:00",
|
||||
"关键位",
|
||||
None,
|
||||
None,
|
||||
"trend",
|
||||
"",
|
||||
None,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
self.assertEqual(len(fetch_trades_for_trading_day(conn, "2026-06-07")), 0)
|
||||
self.assertEqual(len(fetch_trades_for_trading_day(conn, "2026-06-06")), 1)
|
||||
conn.close()
|
||||
|
||||
def test_reviewed_fields_preferred(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""CREATE TABLE trade_records (
|
||||
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
|
||||
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
|
||||
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
|
||||
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
|
||||
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
|
||||
(
|
||||
"ETH/USDT",
|
||||
"long",
|
||||
"止损",
|
||||
"止盈",
|
||||
-0.5,
|
||||
2.0,
|
||||
None,
|
||||
"2026-06-06 09:00:00",
|
||||
"2026-06-06 11:00:00",
|
||||
"2026-06-06 08:00:00",
|
||||
None,
|
||||
"2026-06-06 11:00:00",
|
||||
"趋势回调",
|
||||
None,
|
||||
None,
|
||||
"trend",
|
||||
"",
|
||||
"2026-06-06 12:00:00",
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
rows = fetch_trades_for_trading_day(conn, "2026-06-06")
|
||||
self.assertEqual(len(rows), 1)
|
||||
self.assertEqual(rows[0]["result"], "止盈")
|
||||
self.assertAlmostEqual(rows[0]["pnl_amount"], 2.0)
|
||||
self.assertTrue(rows[0]["reviewed"])
|
||||
conn.close()
|
||||
|
||||
def test_time_close_result_included(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""CREATE TABLE trade_records (
|
||||
symbol TEXT, direction TEXT, result TEXT, reviewed_result TEXT,
|
||||
pnl_amount REAL, reviewed_pnl_amount REAL, exchange_realized_pnl REAL,
|
||||
closed_at TEXT, reviewed_closed_at TEXT, opened_at TEXT, reviewed_opened_at TEXT,
|
||||
created_at TEXT, monitor_type TEXT, actual_rr REAL, planned_rr REAL,
|
||||
trade_style TEXT, entry_reason TEXT, reviewed_at TEXT
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO trade_records VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
|
||||
(
|
||||
"BTC/USDT",
|
||||
"long",
|
||||
"时间平仓",
|
||||
None,
|
||||
1.2,
|
||||
None,
|
||||
None,
|
||||
"2026-06-06 12:00:00",
|
||||
None,
|
||||
"2026-06-06 08:00:00",
|
||||
None,
|
||||
"2026-06-06 12:00:00",
|
||||
"趋势回调",
|
||||
None,
|
||||
None,
|
||||
"trend",
|
||||
"",
|
||||
None,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
rows = fetch_trades_for_trading_day(conn, "2026-06-06")
|
||||
self.assertEqual(len(rows), 1)
|
||||
self.assertEqual(rows[0]["result"], "时间平仓")
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,115 +1,115 @@
|
||||
"""档案交易:复盘字段优先(开仓类型、持仓时长、开平仓时间)。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from hub_symbol_archive_lib import init_db, load_symbol_trades, upsert_trades_cache
|
||||
from hub_trades_lib import (
|
||||
_normalize_archive_trade_row,
|
||||
display_entry_type_label,
|
||||
effective_entry_type,
|
||||
effective_hold_minutes,
|
||||
)
|
||||
|
||||
|
||||
class TestHubTradesReviewFields(unittest.TestCase):
|
||||
def test_display_entry_type_for_manual_monitor_review(self):
|
||||
d = {
|
||||
"monitor_type": "下单监控",
|
||||
"entry_reason": "",
|
||||
"reviewed_entry_reason": "突破回踩",
|
||||
"reviewed_at": "2026-06-08 10:00:00",
|
||||
}
|
||||
self.assertEqual(display_entry_type_label(d), "突破回踩")
|
||||
|
||||
def test_effective_entry_type_prefers_reviewed(self):
|
||||
d = {
|
||||
"entry_reason": "突破回踩",
|
||||
"reviewed_entry_reason": "趋势回调",
|
||||
"monitor_type": "下单监控",
|
||||
}
|
||||
self.assertEqual(effective_entry_type(d), "趋势回调")
|
||||
|
||||
def test_effective_hold_minutes_prefers_reviewed(self):
|
||||
d = {
|
||||
"hold_minutes": 30,
|
||||
"reviewed_hold_minutes": 95,
|
||||
"opened_at_ms": 1_700_000_000_000,
|
||||
"closed_at_ms": 1_700_001_800_000,
|
||||
}
|
||||
self.assertEqual(effective_hold_minutes(d), 95)
|
||||
|
||||
def test_normalize_archive_trade_row_review_fields(self):
|
||||
closed = (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
opened = (datetime.now() - timedelta(days=2, hours=2)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
row = _normalize_archive_trade_row(
|
||||
{
|
||||
"id": 9,
|
||||
"symbol": "ONDO/USDT",
|
||||
"direction": "short",
|
||||
"result": "止损",
|
||||
"reviewed_result": "手动平仓",
|
||||
"pnl_amount": -2.5,
|
||||
"reviewed_pnl_amount": -2.58,
|
||||
"opened_at": opened,
|
||||
"reviewed_opened_at": "2026-06-07 14:30:00",
|
||||
"closed_at": closed,
|
||||
"reviewed_closed_at": "2026-06-08 08:44:21",
|
||||
"opened_at_ms": 1_700_000_000_000,
|
||||
"closed_at_ms": 1_700_007_200_000,
|
||||
"entry_reason": "突破回踩",
|
||||
"reviewed_entry_reason": "趋势回调",
|
||||
"hold_minutes": 30,
|
||||
"reviewed_hold_minutes": 1080,
|
||||
"monitor_type": "趋势回调",
|
||||
"reviewed_at": closed,
|
||||
},
|
||||
exchange_key="gate",
|
||||
)
|
||||
self.assertIsNotNone(row)
|
||||
assert row is not None
|
||||
self.assertEqual(row["entry_type"], "趋势回调")
|
||||
self.assertEqual(row["hold_minutes"], 1080)
|
||||
self.assertEqual(row["opened_at"], "2026-06-07 14:30:00")
|
||||
self.assertEqual(row["closed_at"], "2026-06-08 08:44:21")
|
||||
self.assertTrue(row["reviewed"])
|
||||
|
||||
def test_archive_cache_enriches_review_display_fields(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
upsert_trades_cache(
|
||||
"gate",
|
||||
[
|
||||
{
|
||||
"id": 3,
|
||||
"symbol": "ONDO/USDT",
|
||||
"direction": "short",
|
||||
"result": "手动平仓",
|
||||
"pnl_amount": -2.58,
|
||||
"opened_at": "2026-06-07 14:30:00",
|
||||
"closed_at": "2026-06-08 08:44:21",
|
||||
"opened_at_ms": 1_781_000_000_000,
|
||||
"closed_at_ms": 1_781_065_000_000,
|
||||
"entry_type": "趋势回调",
|
||||
"hold_minutes": 1080,
|
||||
"hold_minutes_text": "18小时0分钟",
|
||||
"reviewed": True,
|
||||
}
|
||||
],
|
||||
db_path=db,
|
||||
)
|
||||
rows = load_symbol_trades("gate", "ONDO/USDT", db_path=db)
|
||||
self.assertEqual(len(rows), 1)
|
||||
self.assertEqual(rows[0]["entry_type"], "趋势回调")
|
||||
self.assertEqual(rows[0]["hold_minutes"], 1080)
|
||||
self.assertTrue(rows[0]["opened_at"].startswith("2026-06-07"))
|
||||
self.assertTrue(rows[0]["closed_at"].startswith("2026-06-08"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""档案交易:复盘字段优先(开仓类型、持仓时长、开平仓时间)。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from lib.hub.hub_symbol_archive_lib import init_db, load_symbol_trades, upsert_trades_cache
|
||||
from lib.hub.hub_trades_lib import (
|
||||
_normalize_archive_trade_row,
|
||||
display_entry_type_label,
|
||||
effective_entry_type,
|
||||
effective_hold_minutes,
|
||||
)
|
||||
|
||||
|
||||
class TestHubTradesReviewFields(unittest.TestCase):
|
||||
def test_display_entry_type_for_manual_monitor_review(self):
|
||||
d = {
|
||||
"monitor_type": "下单监控",
|
||||
"entry_reason": "",
|
||||
"reviewed_entry_reason": "突破回踩",
|
||||
"reviewed_at": "2026-06-08 10:00:00",
|
||||
}
|
||||
self.assertEqual(display_entry_type_label(d), "突破回踩")
|
||||
|
||||
def test_effective_entry_type_prefers_reviewed(self):
|
||||
d = {
|
||||
"entry_reason": "突破回踩",
|
||||
"reviewed_entry_reason": "趋势回调",
|
||||
"monitor_type": "下单监控",
|
||||
}
|
||||
self.assertEqual(effective_entry_type(d), "趋势回调")
|
||||
|
||||
def test_effective_hold_minutes_prefers_reviewed(self):
|
||||
d = {
|
||||
"hold_minutes": 30,
|
||||
"reviewed_hold_minutes": 95,
|
||||
"opened_at_ms": 1_700_000_000_000,
|
||||
"closed_at_ms": 1_700_001_800_000,
|
||||
}
|
||||
self.assertEqual(effective_hold_minutes(d), 95)
|
||||
|
||||
def test_normalize_archive_trade_row_review_fields(self):
|
||||
closed = (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
opened = (datetime.now() - timedelta(days=2, hours=2)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
row = _normalize_archive_trade_row(
|
||||
{
|
||||
"id": 9,
|
||||
"symbol": "ONDO/USDT",
|
||||
"direction": "short",
|
||||
"result": "止损",
|
||||
"reviewed_result": "手动平仓",
|
||||
"pnl_amount": -2.5,
|
||||
"reviewed_pnl_amount": -2.58,
|
||||
"opened_at": opened,
|
||||
"reviewed_opened_at": "2026-06-07 14:30:00",
|
||||
"closed_at": closed,
|
||||
"reviewed_closed_at": "2026-06-08 08:44:21",
|
||||
"opened_at_ms": 1_700_000_000_000,
|
||||
"closed_at_ms": 1_700_007_200_000,
|
||||
"entry_reason": "突破回踩",
|
||||
"reviewed_entry_reason": "趋势回调",
|
||||
"hold_minutes": 30,
|
||||
"reviewed_hold_minutes": 1080,
|
||||
"monitor_type": "趋势回调",
|
||||
"reviewed_at": closed,
|
||||
},
|
||||
exchange_key="gate",
|
||||
)
|
||||
self.assertIsNotNone(row)
|
||||
assert row is not None
|
||||
self.assertEqual(row["entry_type"], "趋势回调")
|
||||
self.assertEqual(row["hold_minutes"], 1080)
|
||||
self.assertEqual(row["opened_at"], "2026-06-07 14:30:00")
|
||||
self.assertEqual(row["closed_at"], "2026-06-08 08:44:21")
|
||||
self.assertTrue(row["reviewed"])
|
||||
|
||||
def test_archive_cache_enriches_review_display_fields(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = Path(td) / "archive.db"
|
||||
init_db(db)
|
||||
upsert_trades_cache(
|
||||
"gate",
|
||||
[
|
||||
{
|
||||
"id": 3,
|
||||
"symbol": "ONDO/USDT",
|
||||
"direction": "short",
|
||||
"result": "手动平仓",
|
||||
"pnl_amount": -2.58,
|
||||
"opened_at": "2026-06-07 14:30:00",
|
||||
"closed_at": "2026-06-08 08:44:21",
|
||||
"opened_at_ms": 1_781_000_000_000,
|
||||
"closed_at_ms": 1_781_065_000_000,
|
||||
"entry_type": "趋势回调",
|
||||
"hold_minutes": 1080,
|
||||
"hold_minutes_text": "18小时0分钟",
|
||||
"reviewed": True,
|
||||
}
|
||||
],
|
||||
db_path=db,
|
||||
)
|
||||
rows = load_symbol_trades("gate", "ONDO/USDT", db_path=db)
|
||||
self.assertEqual(len(rows), 1)
|
||||
self.assertEqual(rows[0]["entry_type"], "趋势回调")
|
||||
self.assertEqual(rows[0]["hold_minutes"], 1080)
|
||||
self.assertTrue(rows[0]["opened_at"].startswith("2026-06-07"))
|
||||
self.assertTrue(rows[0]["closed_at"].startswith("2026-06-08"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
+184
-184
@@ -1,184 +1,184 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from hub_volume_rank_lib import (
|
||||
CACHE_VERSION,
|
||||
LIQUIDITY_RANK_CACHE_VERSION,
|
||||
TOP_N_DEFAULT,
|
||||
_exchange_rank_row_stale,
|
||||
_okx_turnover_usdt,
|
||||
_scores_from_binance,
|
||||
_scores_from_gate,
|
||||
build_usdt_swap_volume_ranks,
|
||||
cache_needs_refresh,
|
||||
format_volume_quote,
|
||||
merge_exchange_rank,
|
||||
rank_date_label,
|
||||
resolve_daily_volume_rank,
|
||||
)
|
||||
|
||||
|
||||
def test_rank_date_label_after_reset():
|
||||
# 2026-06-08 09:00 北京时间 → 昨日交易日 2026-06-07
|
||||
dt = datetime(2026, 6, 8, 9, 0, 0)
|
||||
assert rank_date_label(now=dt, reset_hour=8) == "2026-06-07"
|
||||
|
||||
|
||||
def test_rank_date_label_before_reset():
|
||||
# 2026-06-08 07:00 → 当前交易日仍算 2026-06-07,昨日为 2026-06-06
|
||||
dt = datetime(2026, 6, 8, 7, 0, 0)
|
||||
assert rank_date_label(now=dt, reset_hour=8) == "2026-06-06"
|
||||
|
||||
|
||||
def test_format_volume_quote():
|
||||
assert format_volume_quote(1_500_000_000) == "1.50B"
|
||||
assert format_volume_quote(2_300_000) == "2.30M"
|
||||
assert format_volume_quote(4500) == "4.50K"
|
||||
|
||||
|
||||
def test_okx_turnover_usdt():
|
||||
qv = _okx_turnover_usdt({"volCcy24h": "100", "last": "50"})
|
||||
assert qv == 5000.0
|
||||
|
||||
|
||||
def test_cache_needs_refresh_and_merge():
|
||||
cache = {"rank_date": "2026-06-05", "exchanges": {}}
|
||||
assert cache_needs_refresh(cache, expected_rank_date="2026-06-07") is True
|
||||
merged = merge_exchange_rank(
|
||||
cache,
|
||||
"binance",
|
||||
{
|
||||
"ok": True,
|
||||
"rank_date": "2026-06-07",
|
||||
"items": [{"rank": 1, "symbol": "BTC/USDT", "volume_quote": 1.0}],
|
||||
"total_symbols": 100,
|
||||
},
|
||||
)
|
||||
assert merged["exchanges"]["binance"]["items"][0]["symbol"] == "BTC/USDT"
|
||||
assert merged["rank_date"] == "2026-06-07"
|
||||
|
||||
|
||||
def test_stale_cache_version_forces_refresh():
|
||||
cache = {"version": CACHE_VERSION - 1, "rank_date": "2026-06-07", "exchanges": {"okx": {"items": [{}]}}}
|
||||
assert cache_needs_refresh(cache) is True
|
||||
|
||||
|
||||
def test_short_item_list_is_stale():
|
||||
items = [{"rank": i, "symbol": f"S{i}/USDT"} for i in range(1, 13)]
|
||||
row = {"items": items, "total_symbols": 12}
|
||||
assert _exchange_rank_row_stale(row) is True
|
||||
full = {"items": items + [{"rank": i, "symbol": f"X{i}/USDT"} for i in range(13, TOP_N_DEFAULT + 1)], "total_symbols": 300}
|
||||
assert _exchange_rank_row_stale(full) is False
|
||||
|
||||
|
||||
def test_scores_from_binance_uses_fapi_lightweight_api():
|
||||
ex = MagicMock()
|
||||
ex.id = "binance"
|
||||
ex.fapiPublicGetTicker24hr.return_value = [
|
||||
{"symbol": "BTCUSDT", "quoteVolume": "9000000"},
|
||||
{"symbol": "ETHUSDT", "quoteVolume": "5000000"},
|
||||
]
|
||||
scored = _scores_from_binance(ex)
|
||||
assert scored[0][1] == "BTC"
|
||||
assert scored[0][2] == 9000000.0
|
||||
ex.fetch_tickers.assert_not_called()
|
||||
|
||||
|
||||
def test_scores_from_binance_skips_fetch_tickers_on_api_error():
|
||||
ex = MagicMock()
|
||||
ex.id = "binance"
|
||||
ex.fapiPublicGetTicker24hr.side_effect = RuntimeError("network")
|
||||
scored = _scores_from_binance(ex)
|
||||
assert scored == []
|
||||
ex.fetch_tickers.assert_not_called()
|
||||
|
||||
|
||||
def test_scores_from_gate_uses_futures_tickers_api():
|
||||
ex = MagicMock()
|
||||
ex.id = "gateio"
|
||||
ex.publicFuturesGetSettleTickers.return_value = [
|
||||
{"contract": "BTC_USDT", "volume_24h_quote": "8000000"},
|
||||
{"contract": "ETH_USDT", "volume_24h_quote": "4000000"},
|
||||
]
|
||||
scored = _scores_from_gate(ex)
|
||||
assert scored[0][1] == "BTC"
|
||||
ex.fetch_tickers.assert_not_called()
|
||||
|
||||
|
||||
def test_scores_from_gate_skips_fetch_tickers_on_api_error():
|
||||
ex = MagicMock()
|
||||
ex.id = "gateio"
|
||||
ex.publicFuturesGetSettleTickers.side_effect = RuntimeError("network")
|
||||
scored = _scores_from_gate(ex)
|
||||
assert scored == []
|
||||
ex.fetch_tickers.assert_not_called()
|
||||
|
||||
|
||||
def test_resolve_daily_volume_rank_caches_result():
|
||||
cache = {"version": 0, "updated_at": 0.0, "ranks": {}, "total": 0}
|
||||
ex = MagicMock()
|
||||
ex.id = "binance"
|
||||
ex.fapiPublicGetTicker24hr.return_value = [
|
||||
{"symbol": "BTCUSDT", "quoteVolume": "100"},
|
||||
{"symbol": "ETHUSDT", "quoteVolume": "50"},
|
||||
]
|
||||
|
||||
rank, total = resolve_daily_volume_rank(
|
||||
"BTC",
|
||||
cache,
|
||||
now_ts=1000.0,
|
||||
ttl_sec=60.0,
|
||||
exchange=ex,
|
||||
ensure_markets_loaded=lambda: None,
|
||||
)
|
||||
assert rank == 1
|
||||
assert total == 2
|
||||
assert cache["version"] == LIQUIDITY_RANK_CACHE_VERSION
|
||||
calls = ex.fapiPublicGetTicker24hr.call_count
|
||||
|
||||
rank2, _ = resolve_daily_volume_rank(
|
||||
"BTC",
|
||||
cache,
|
||||
now_ts=1010.0,
|
||||
ttl_sec=60.0,
|
||||
exchange=ex,
|
||||
ensure_markets_loaded=lambda: None,
|
||||
)
|
||||
assert rank2 == 1
|
||||
assert ex.fapiPublicGetTicker24hr.call_count == calls
|
||||
|
||||
|
||||
def test_resolve_daily_volume_rank_keeps_stale_cache_when_refresh_empty():
|
||||
cache = {
|
||||
"version": LIQUIDITY_RANK_CACHE_VERSION,
|
||||
"updated_at": 900.0,
|
||||
"ranks": {"BTC": 1},
|
||||
"total": 100,
|
||||
}
|
||||
ex = MagicMock()
|
||||
ex.id = "binance"
|
||||
ex.fapiPublicGetTicker24hr.return_value = []
|
||||
|
||||
rank, total = resolve_daily_volume_rank(
|
||||
"BTC",
|
||||
cache,
|
||||
now_ts=2000.0,
|
||||
ttl_sec=60.0,
|
||||
exchange=ex,
|
||||
ensure_markets_loaded=lambda: None,
|
||||
)
|
||||
assert rank == 1
|
||||
assert total == 100
|
||||
assert cache["updated_at"] == 900.0
|
||||
ex.fetch_tickers.assert_not_called()
|
||||
|
||||
|
||||
def test_build_usdt_swap_volume_ranks():
|
||||
ex = MagicMock()
|
||||
ex.id = "binance"
|
||||
ex.fapiPublicGetTicker24hr.return_value = [
|
||||
{"symbol": "SOLUSDT", "quoteVolume": "200"},
|
||||
]
|
||||
ranks, total = build_usdt_swap_volume_ranks(ex, lambda: None)
|
||||
assert ranks["SOL"] == 1
|
||||
assert total == 1
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from lib.hub.hub_volume_rank_lib import (
|
||||
CACHE_VERSION,
|
||||
LIQUIDITY_RANK_CACHE_VERSION,
|
||||
TOP_N_DEFAULT,
|
||||
_exchange_rank_row_stale,
|
||||
_okx_turnover_usdt,
|
||||
_scores_from_binance,
|
||||
_scores_from_gate,
|
||||
build_usdt_swap_volume_ranks,
|
||||
cache_needs_refresh,
|
||||
format_volume_quote,
|
||||
merge_exchange_rank,
|
||||
rank_date_label,
|
||||
resolve_daily_volume_rank,
|
||||
)
|
||||
|
||||
|
||||
def test_rank_date_label_after_reset():
|
||||
# 2026-06-08 09:00 北京时间 → 昨日交易日 2026-06-07
|
||||
dt = datetime(2026, 6, 8, 9, 0, 0)
|
||||
assert rank_date_label(now=dt, reset_hour=8) == "2026-06-07"
|
||||
|
||||
|
||||
def test_rank_date_label_before_reset():
|
||||
# 2026-06-08 07:00 → 当前交易日仍算 2026-06-07,昨日为 2026-06-06
|
||||
dt = datetime(2026, 6, 8, 7, 0, 0)
|
||||
assert rank_date_label(now=dt, reset_hour=8) == "2026-06-06"
|
||||
|
||||
|
||||
def test_format_volume_quote():
|
||||
assert format_volume_quote(1_500_000_000) == "1.50B"
|
||||
assert format_volume_quote(2_300_000) == "2.30M"
|
||||
assert format_volume_quote(4500) == "4.50K"
|
||||
|
||||
|
||||
def test_okx_turnover_usdt():
|
||||
qv = _okx_turnover_usdt({"volCcy24h": "100", "last": "50"})
|
||||
assert qv == 5000.0
|
||||
|
||||
|
||||
def test_cache_needs_refresh_and_merge():
|
||||
cache = {"rank_date": "2026-06-05", "exchanges": {}}
|
||||
assert cache_needs_refresh(cache, expected_rank_date="2026-06-07") is True
|
||||
merged = merge_exchange_rank(
|
||||
cache,
|
||||
"binance",
|
||||
{
|
||||
"ok": True,
|
||||
"rank_date": "2026-06-07",
|
||||
"items": [{"rank": 1, "symbol": "BTC/USDT", "volume_quote": 1.0}],
|
||||
"total_symbols": 100,
|
||||
},
|
||||
)
|
||||
assert merged["exchanges"]["binance"]["items"][0]["symbol"] == "BTC/USDT"
|
||||
assert merged["rank_date"] == "2026-06-07"
|
||||
|
||||
|
||||
def test_stale_cache_version_forces_refresh():
|
||||
cache = {"version": CACHE_VERSION - 1, "rank_date": "2026-06-07", "exchanges": {"okx": {"items": [{}]}}}
|
||||
assert cache_needs_refresh(cache) is True
|
||||
|
||||
|
||||
def test_short_item_list_is_stale():
|
||||
items = [{"rank": i, "symbol": f"S{i}/USDT"} for i in range(1, 13)]
|
||||
row = {"items": items, "total_symbols": 12}
|
||||
assert _exchange_rank_row_stale(row) is True
|
||||
full = {"items": items + [{"rank": i, "symbol": f"X{i}/USDT"} for i in range(13, TOP_N_DEFAULT + 1)], "total_symbols": 300}
|
||||
assert _exchange_rank_row_stale(full) is False
|
||||
|
||||
|
||||
def test_scores_from_binance_uses_fapi_lightweight_api():
|
||||
ex = MagicMock()
|
||||
ex.id = "binance"
|
||||
ex.fapiPublicGetTicker24hr.return_value = [
|
||||
{"symbol": "BTCUSDT", "quoteVolume": "9000000"},
|
||||
{"symbol": "ETHUSDT", "quoteVolume": "5000000"},
|
||||
]
|
||||
scored = _scores_from_binance(ex)
|
||||
assert scored[0][1] == "BTC"
|
||||
assert scored[0][2] == 9000000.0
|
||||
ex.fetch_tickers.assert_not_called()
|
||||
|
||||
|
||||
def test_scores_from_binance_skips_fetch_tickers_on_api_error():
|
||||
ex = MagicMock()
|
||||
ex.id = "binance"
|
||||
ex.fapiPublicGetTicker24hr.side_effect = RuntimeError("network")
|
||||
scored = _scores_from_binance(ex)
|
||||
assert scored == []
|
||||
ex.fetch_tickers.assert_not_called()
|
||||
|
||||
|
||||
def test_scores_from_gate_uses_futures_tickers_api():
|
||||
ex = MagicMock()
|
||||
ex.id = "gateio"
|
||||
ex.publicFuturesGetSettleTickers.return_value = [
|
||||
{"contract": "BTC_USDT", "volume_24h_quote": "8000000"},
|
||||
{"contract": "ETH_USDT", "volume_24h_quote": "4000000"},
|
||||
]
|
||||
scored = _scores_from_gate(ex)
|
||||
assert scored[0][1] == "BTC"
|
||||
ex.fetch_tickers.assert_not_called()
|
||||
|
||||
|
||||
def test_scores_from_gate_skips_fetch_tickers_on_api_error():
|
||||
ex = MagicMock()
|
||||
ex.id = "gateio"
|
||||
ex.publicFuturesGetSettleTickers.side_effect = RuntimeError("network")
|
||||
scored = _scores_from_gate(ex)
|
||||
assert scored == []
|
||||
ex.fetch_tickers.assert_not_called()
|
||||
|
||||
|
||||
def test_resolve_daily_volume_rank_caches_result():
|
||||
cache = {"version": 0, "updated_at": 0.0, "ranks": {}, "total": 0}
|
||||
ex = MagicMock()
|
||||
ex.id = "binance"
|
||||
ex.fapiPublicGetTicker24hr.return_value = [
|
||||
{"symbol": "BTCUSDT", "quoteVolume": "100"},
|
||||
{"symbol": "ETHUSDT", "quoteVolume": "50"},
|
||||
]
|
||||
|
||||
rank, total = resolve_daily_volume_rank(
|
||||
"BTC",
|
||||
cache,
|
||||
now_ts=1000.0,
|
||||
ttl_sec=60.0,
|
||||
exchange=ex,
|
||||
ensure_markets_loaded=lambda: None,
|
||||
)
|
||||
assert rank == 1
|
||||
assert total == 2
|
||||
assert cache["version"] == LIQUIDITY_RANK_CACHE_VERSION
|
||||
calls = ex.fapiPublicGetTicker24hr.call_count
|
||||
|
||||
rank2, _ = resolve_daily_volume_rank(
|
||||
"BTC",
|
||||
cache,
|
||||
now_ts=1010.0,
|
||||
ttl_sec=60.0,
|
||||
exchange=ex,
|
||||
ensure_markets_loaded=lambda: None,
|
||||
)
|
||||
assert rank2 == 1
|
||||
assert ex.fapiPublicGetTicker24hr.call_count == calls
|
||||
|
||||
|
||||
def test_resolve_daily_volume_rank_keeps_stale_cache_when_refresh_empty():
|
||||
cache = {
|
||||
"version": LIQUIDITY_RANK_CACHE_VERSION,
|
||||
"updated_at": 900.0,
|
||||
"ranks": {"BTC": 1},
|
||||
"total": 100,
|
||||
}
|
||||
ex = MagicMock()
|
||||
ex.id = "binance"
|
||||
ex.fapiPublicGetTicker24hr.return_value = []
|
||||
|
||||
rank, total = resolve_daily_volume_rank(
|
||||
"BTC",
|
||||
cache,
|
||||
now_ts=2000.0,
|
||||
ttl_sec=60.0,
|
||||
exchange=ex,
|
||||
ensure_markets_loaded=lambda: None,
|
||||
)
|
||||
assert rank == 1
|
||||
assert total == 100
|
||||
assert cache["updated_at"] == 900.0
|
||||
ex.fetch_tickers.assert_not_called()
|
||||
|
||||
|
||||
def test_build_usdt_swap_volume_ranks():
|
||||
ex = MagicMock()
|
||||
ex.id = "binance"
|
||||
ex.fapiPublicGetTicker24hr.return_value = [
|
||||
{"symbol": "SOLUSDT", "quoteVolume": "200"},
|
||||
]
|
||||
ranks, total = build_usdt_swap_volume_ranks(ex, lambda: None)
|
||||
assert ranks["SOL"] == 1
|
||||
assert total == 1
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
from instance_embed_context_lib import embed_render_plan, trade_records_summary
|
||||
|
||||
|
||||
def test_embed_fragment_trade_is_light():
|
||||
plan = embed_render_plan("trade", "fragment")
|
||||
assert plan.exchange_capitals is False
|
||||
assert plan.records_rows is False
|
||||
assert plan.records_summary is False
|
||||
assert plan.orders is True
|
||||
assert plan.key_history is False
|
||||
|
||||
|
||||
def test_embed_shell_trade_summary_only():
|
||||
plan = embed_render_plan("trade", "shell")
|
||||
assert plan.exchange_capitals is True
|
||||
assert plan.records_summary is True
|
||||
assert plan.records_rows is False
|
||||
|
||||
|
||||
def test_embed_records_page_loads_rows():
|
||||
plan = embed_render_plan("records", "fragment")
|
||||
assert plan.records_rows is True
|
||||
|
||||
|
||||
def test_full_page_unchanged():
|
||||
plan = embed_render_plan("trade", None)
|
||||
assert plan.records_rows is True
|
||||
assert plan.exchange_capitals is True
|
||||
from lib.instance.instance_embed_context_lib import embed_render_plan, trade_records_summary
|
||||
|
||||
|
||||
def test_embed_fragment_trade_is_light():
|
||||
plan = embed_render_plan("trade", "fragment")
|
||||
assert plan.exchange_capitals is False
|
||||
assert plan.records_rows is False
|
||||
assert plan.records_summary is False
|
||||
assert plan.orders is True
|
||||
assert plan.key_history is False
|
||||
|
||||
|
||||
def test_embed_shell_trade_summary_only():
|
||||
plan = embed_render_plan("trade", "shell")
|
||||
assert plan.exchange_capitals is True
|
||||
assert plan.records_summary is True
|
||||
assert plan.records_rows is False
|
||||
|
||||
|
||||
def test_embed_records_page_loads_rows():
|
||||
plan = embed_render_plan("records", "fragment")
|
||||
assert plan.records_rows is True
|
||||
|
||||
|
||||
def test_full_page_unchanged():
|
||||
plan = embed_render_plan("trade", None)
|
||||
assert plan.records_rows is True
|
||||
assert plan.exchange_capitals is True
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
from instance_embed_lib import (
|
||||
EMBED_TABS,
|
||||
path_to_embed_tab,
|
||||
rewrite_embed_dest,
|
||||
)
|
||||
|
||||
|
||||
def test_path_to_embed_tab():
|
||||
assert path_to_embed_tab("/trade") == "trade"
|
||||
assert path_to_embed_tab("/key_monitor") == "key_monitor"
|
||||
assert path_to_embed_tab("/strategy/records") == "strategy_records"
|
||||
assert path_to_embed_tab("/unknown") is None
|
||||
|
||||
|
||||
def test_rewrite_embed_dest():
|
||||
url = rewrite_embed_dest("/trade", hub_theme="dark")
|
||||
assert url.startswith("/embed?")
|
||||
assert "tab=trade" in url
|
||||
assert "embed=1" in url
|
||||
assert "hub_theme=dark" in url
|
||||
|
||||
|
||||
def test_embed_tabs_cover_main_nav():
|
||||
assert "trade" in EMBED_TABS
|
||||
assert "key_monitor" in EMBED_TABS
|
||||
assert "records" in EMBED_TABS
|
||||
from lib.instance.instance_embed_lib import (
|
||||
EMBED_TABS,
|
||||
path_to_embed_tab,
|
||||
rewrite_embed_dest,
|
||||
)
|
||||
|
||||
|
||||
def test_path_to_embed_tab():
|
||||
assert path_to_embed_tab("/trade") == "trade"
|
||||
assert path_to_embed_tab("/key_monitor") == "key_monitor"
|
||||
assert path_to_embed_tab("/strategy/records") == "strategy_records"
|
||||
assert path_to_embed_tab("/unknown") is None
|
||||
|
||||
|
||||
def test_rewrite_embed_dest():
|
||||
url = rewrite_embed_dest("/trade", hub_theme="dark")
|
||||
assert url.startswith("/embed?")
|
||||
assert "tab=trade" in url
|
||||
assert "embed=1" in url
|
||||
assert "hub_theme=dark" in url
|
||||
|
||||
|
||||
def test_embed_tabs_cover_main_nav():
|
||||
assert "trade" in EMBED_TABS
|
||||
assert "key_monitor" in EMBED_TABS
|
||||
assert "records" in EMBED_TABS
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
from instance_nav_lib import request_is_hub_soft_nav
|
||||
|
||||
|
||||
def test_request_is_hub_soft_nav():
|
||||
class Req:
|
||||
args = {"embed": "1"}
|
||||
headers = {"X-Instance-Soft-Nav": "1"}
|
||||
|
||||
assert request_is_hub_soft_nav(Req()) is True
|
||||
|
||||
class Req2:
|
||||
args = {"embed": "1"}
|
||||
headers = {}
|
||||
|
||||
assert request_is_hub_soft_nav(Req2()) is False
|
||||
|
||||
class Req3:
|
||||
args = {}
|
||||
headers = {"X-Instance-Soft-Nav": "1"}
|
||||
|
||||
assert request_is_hub_soft_nav(Req3()) is False
|
||||
from lib.instance.instance_nav_lib import request_is_hub_soft_nav
|
||||
|
||||
|
||||
def test_request_is_hub_soft_nav():
|
||||
class Req:
|
||||
args = {"embed": "1"}
|
||||
headers = {"X-Instance-Soft-Nav": "1"}
|
||||
|
||||
assert request_is_hub_soft_nav(Req()) is True
|
||||
|
||||
class Req2:
|
||||
args = {"embed": "1"}
|
||||
headers = {}
|
||||
|
||||
assert request_is_hub_soft_nav(Req2()) is False
|
||||
|
||||
class Req3:
|
||||
args = {}
|
||||
headers = {"X-Instance-Soft-Nav": "1"}
|
||||
|
||||
assert request_is_hub_soft_nav(Req3()) is False
|
||||
|
||||
@@ -1,34 +1,34 @@
|
||||
import unittest
|
||||
|
||||
from key_monitor_lib import (
|
||||
BOX_BREAKOUT_CLOSE_OPPOSITE,
|
||||
box_breakout_invalidate_by_mark,
|
||||
box_breakout_invalidate_edge_label,
|
||||
)
|
||||
|
||||
|
||||
class BoxBreakoutInvalidateTests(unittest.TestCase):
|
||||
def test_short_invalidates_above_upper(self):
|
||||
self.assertTrue(box_breakout_invalidate_by_mark("short", 62.511, 61.746, 60.569))
|
||||
|
||||
def test_short_stays_valid_inside_or_below(self):
|
||||
self.assertFalse(box_breakout_invalidate_by_mark("short", 61.0, 61.746, 60.569))
|
||||
self.assertFalse(box_breakout_invalidate_by_mark("short", 60.0, 61.746, 60.569))
|
||||
|
||||
def test_long_invalidates_below_lower(self):
|
||||
self.assertTrue(box_breakout_invalidate_by_mark("long", 94.0, 100.0, 95.0))
|
||||
|
||||
def test_long_stays_valid_inside_or_above(self):
|
||||
self.assertFalse(box_breakout_invalidate_by_mark("long", 98.0, 100.0, 95.0))
|
||||
self.assertFalse(box_breakout_invalidate_by_mark("long", 101.0, 100.0, 95.0))
|
||||
|
||||
def test_edge_label(self):
|
||||
self.assertEqual(box_breakout_invalidate_edge_label("long"), "下沿")
|
||||
self.assertEqual(box_breakout_invalidate_edge_label("short"), "上沿")
|
||||
|
||||
def test_close_reason_constant(self):
|
||||
self.assertEqual(BOX_BREAKOUT_CLOSE_OPPOSITE, "box_opposite_break")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
import unittest
|
||||
|
||||
from lib.key_monitor.key_monitor_lib import (
|
||||
BOX_BREAKOUT_CLOSE_OPPOSITE,
|
||||
box_breakout_invalidate_by_mark,
|
||||
box_breakout_invalidate_edge_label,
|
||||
)
|
||||
|
||||
|
||||
class BoxBreakoutInvalidateTests(unittest.TestCase):
|
||||
def test_short_invalidates_above_upper(self):
|
||||
self.assertTrue(box_breakout_invalidate_by_mark("short", 62.511, 61.746, 60.569))
|
||||
|
||||
def test_short_stays_valid_inside_or_below(self):
|
||||
self.assertFalse(box_breakout_invalidate_by_mark("short", 61.0, 61.746, 60.569))
|
||||
self.assertFalse(box_breakout_invalidate_by_mark("short", 60.0, 61.746, 60.569))
|
||||
|
||||
def test_long_invalidates_below_lower(self):
|
||||
self.assertTrue(box_breakout_invalidate_by_mark("long", 94.0, 100.0, 95.0))
|
||||
|
||||
def test_long_stays_valid_inside_or_above(self):
|
||||
self.assertFalse(box_breakout_invalidate_by_mark("long", 98.0, 100.0, 95.0))
|
||||
self.assertFalse(box_breakout_invalidate_by_mark("long", 101.0, 100.0, 95.0))
|
||||
|
||||
def test_edge_label(self):
|
||||
self.assertEqual(box_breakout_invalidate_edge_label("long"), "下沿")
|
||||
self.assertEqual(box_breakout_invalidate_edge_label("short"), "上沿")
|
||||
|
||||
def test_close_reason_constant(self):
|
||||
self.assertEqual(BOX_BREAKOUT_CLOSE_OPPOSITE, "box_opposite_break")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,86 +1,86 @@
|
||||
"""阻力/支撑提醒:占位与间隔防重复推送。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from key_monitor_lib import (
|
||||
claim_rs_level_notify,
|
||||
notify_interval_elapsed,
|
||||
run_rs_level_alert_tick,
|
||||
)
|
||||
|
||||
|
||||
def _row(**kwargs):
|
||||
base = {
|
||||
"upper": 2.174,
|
||||
"lower": 1.694,
|
||||
"notification_count": 0,
|
||||
"max_notify": 3,
|
||||
"notify_interval_min": 5,
|
||||
"direction": "watch",
|
||||
"last_notified_at": None,
|
||||
"last_rs_bar_ts": None,
|
||||
}
|
||||
base.update(kwargs)
|
||||
return base
|
||||
|
||||
|
||||
class TestRsLevelAlertClaim(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.conn = sqlite3.connect(":memory:")
|
||||
self.conn.execute(
|
||||
"CREATE TABLE key_monitors ("
|
||||
"id INTEGER PRIMARY KEY, notification_count INTEGER DEFAULT 0, "
|
||||
"direction TEXT, last_notified_at TEXT, last_rs_bar_ts INTEGER)"
|
||||
)
|
||||
self.conn.execute(
|
||||
"INSERT INTO key_monitors (id, notification_count, direction) VALUES (1, 0, 'watch')"
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def test_claim_advances_once_per_index(self):
|
||||
ok1 = claim_rs_level_notify(
|
||||
self.conn, 1, 1, "long", "2026-06-02 00:25:00", 1000, prior_count=0
|
||||
)
|
||||
self.conn.commit()
|
||||
self.assertTrue(ok1)
|
||||
ok_dup = claim_rs_level_notify(
|
||||
self.conn, 1, 1, "long", "2026-06-02 00:25:03", 1000, prior_count=0
|
||||
)
|
||||
self.assertFalse(ok_dup)
|
||||
ok2 = claim_rs_level_notify(
|
||||
self.conn, 1, 2, "long", "2026-06-02 00:30:00", 1000, prior_count=1
|
||||
)
|
||||
self.conn.commit()
|
||||
self.assertTrue(ok2)
|
||||
row = self.conn.execute(
|
||||
"SELECT notification_count FROM key_monitors WHERE id=1"
|
||||
).fetchone()
|
||||
self.assertEqual(row[0], 2)
|
||||
|
||||
def test_second_push_requires_interval(self):
|
||||
now = datetime(2026, 6, 2, 0, 26, 0)
|
||||
row = _row(
|
||||
notification_count=1,
|
||||
direction="long",
|
||||
last_notified_at="2026-06-02 00:25:00",
|
||||
)
|
||||
tick = run_rs_level_alert_tick(row, 2.18, 1000, now, default_max_notify=3, default_interval_min=5)
|
||||
self.assertIsNone(tick)
|
||||
later = datetime(2026, 6, 2, 0, 30, 1)
|
||||
tick2 = run_rs_level_alert_tick(
|
||||
row, 2.18, 1000, later, default_max_notify=3, default_interval_min=5
|
||||
)
|
||||
self.assertIsNotNone(tick2)
|
||||
self.assertEqual(tick2["notify_index"], 2)
|
||||
self.assertEqual(tick2["prior_count"], 1)
|
||||
|
||||
def test_notify_interval_invalid_timestamp_does_not_spam(self):
|
||||
now = datetime(2026, 6, 2, 1, 0, 0)
|
||||
self.assertFalse(notify_interval_elapsed("not-a-date", 5, now))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""阻力/支撑提醒:占位与间隔防重复推送。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from lib.key_monitor.key_monitor_lib import (
|
||||
claim_rs_level_notify,
|
||||
notify_interval_elapsed,
|
||||
run_rs_level_alert_tick,
|
||||
)
|
||||
|
||||
|
||||
def _row(**kwargs):
|
||||
base = {
|
||||
"upper": 2.174,
|
||||
"lower": 1.694,
|
||||
"notification_count": 0,
|
||||
"max_notify": 3,
|
||||
"notify_interval_min": 5,
|
||||
"direction": "watch",
|
||||
"last_notified_at": None,
|
||||
"last_rs_bar_ts": None,
|
||||
}
|
||||
base.update(kwargs)
|
||||
return base
|
||||
|
||||
|
||||
class TestRsLevelAlertClaim(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.conn = sqlite3.connect(":memory:")
|
||||
self.conn.execute(
|
||||
"CREATE TABLE key_monitors ("
|
||||
"id INTEGER PRIMARY KEY, notification_count INTEGER DEFAULT 0, "
|
||||
"direction TEXT, last_notified_at TEXT, last_rs_bar_ts INTEGER)"
|
||||
)
|
||||
self.conn.execute(
|
||||
"INSERT INTO key_monitors (id, notification_count, direction) VALUES (1, 0, 'watch')"
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def test_claim_advances_once_per_index(self):
|
||||
ok1 = claim_rs_level_notify(
|
||||
self.conn, 1, 1, "long", "2026-06-02 00:25:00", 1000, prior_count=0
|
||||
)
|
||||
self.conn.commit()
|
||||
self.assertTrue(ok1)
|
||||
ok_dup = claim_rs_level_notify(
|
||||
self.conn, 1, 1, "long", "2026-06-02 00:25:03", 1000, prior_count=0
|
||||
)
|
||||
self.assertFalse(ok_dup)
|
||||
ok2 = claim_rs_level_notify(
|
||||
self.conn, 1, 2, "long", "2026-06-02 00:30:00", 1000, prior_count=1
|
||||
)
|
||||
self.conn.commit()
|
||||
self.assertTrue(ok2)
|
||||
row = self.conn.execute(
|
||||
"SELECT notification_count FROM key_monitors WHERE id=1"
|
||||
).fetchone()
|
||||
self.assertEqual(row[0], 2)
|
||||
|
||||
def test_second_push_requires_interval(self):
|
||||
now = datetime(2026, 6, 2, 0, 26, 0)
|
||||
row = _row(
|
||||
notification_count=1,
|
||||
direction="long",
|
||||
last_notified_at="2026-06-02 00:25:00",
|
||||
)
|
||||
tick = run_rs_level_alert_tick(row, 2.18, 1000, now, default_max_notify=3, default_interval_min=5)
|
||||
self.assertIsNone(tick)
|
||||
later = datetime(2026, 6, 2, 0, 30, 1)
|
||||
tick2 = run_rs_level_alert_tick(
|
||||
row, 2.18, 1000, later, default_max_notify=3, default_interval_min=5
|
||||
)
|
||||
self.assertIsNotNone(tick2)
|
||||
self.assertEqual(tick2["notify_index"], 2)
|
||||
self.assertEqual(tick2["prior_count"], 1)
|
||||
|
||||
def test_notify_interval_invalid_timestamp_does_not_spam(self):
|
||||
now = datetime(2026, 6, 2, 1, 0, 0)
|
||||
self.assertFalse(notify_interval_elapsed("not-a-date", 5, now))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,27 +1,27 @@
|
||||
import unittest
|
||||
|
||||
from key_monitor_lib import (
|
||||
KEY_MONITOR_RS_TYPE,
|
||||
is_rs_key_monitor_type,
|
||||
rs_monitor_type_for_storage,
|
||||
rs_monitor_type_label,
|
||||
)
|
||||
|
||||
|
||||
class KeyMonitorRsTypeTests(unittest.TestCase):
|
||||
def test_legacy_types_still_recognized(self):
|
||||
self.assertTrue(is_rs_key_monitor_type("关键阻力位"))
|
||||
self.assertTrue(is_rs_key_monitor_type("关键支撑位"))
|
||||
|
||||
def test_storage_normalizes_to_unified_type(self):
|
||||
self.assertEqual(rs_monitor_type_for_storage("关键阻力位"), KEY_MONITOR_RS_TYPE)
|
||||
self.assertEqual(rs_monitor_type_for_storage("关键支撑位"), KEY_MONITOR_RS_TYPE)
|
||||
self.assertEqual(rs_monitor_type_for_storage(KEY_MONITOR_RS_TYPE), KEY_MONITOR_RS_TYPE)
|
||||
|
||||
def test_label_merges_legacy_display(self):
|
||||
self.assertEqual(rs_monitor_type_label("关键阻力位"), KEY_MONITOR_RS_TYPE)
|
||||
self.assertEqual(rs_monitor_type_label("箱体突破"), "箱体突破")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
import unittest
|
||||
|
||||
from lib.key_monitor.key_monitor_lib import (
|
||||
KEY_MONITOR_RS_TYPE,
|
||||
is_rs_key_monitor_type,
|
||||
rs_monitor_type_for_storage,
|
||||
rs_monitor_type_label,
|
||||
)
|
||||
|
||||
|
||||
class KeyMonitorRsTypeTests(unittest.TestCase):
|
||||
def test_legacy_types_still_recognized(self):
|
||||
self.assertTrue(is_rs_key_monitor_type("关键阻力位"))
|
||||
self.assertTrue(is_rs_key_monitor_type("关键支撑位"))
|
||||
|
||||
def test_storage_normalizes_to_unified_type(self):
|
||||
self.assertEqual(rs_monitor_type_for_storage("关键阻力位"), KEY_MONITOR_RS_TYPE)
|
||||
self.assertEqual(rs_monitor_type_for_storage("关键支撑位"), KEY_MONITOR_RS_TYPE)
|
||||
self.assertEqual(rs_monitor_type_for_storage(KEY_MONITOR_RS_TYPE), KEY_MONITOR_RS_TYPE)
|
||||
|
||||
def test_label_merges_legacy_display(self):
|
||||
self.assertEqual(rs_monitor_type_label("关键阻力位"), KEY_MONITOR_RS_TYPE)
|
||||
self.assertEqual(rs_monitor_type_label("箱体突破"), "箱体突破")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,32 +1,32 @@
|
||||
from manual_sltp_lib import (
|
||||
MANUAL_FIXED_RR_DEFAULT,
|
||||
calc_tp_from_fixed_rr,
|
||||
parse_fixed_rr,
|
||||
resolve_open_sltp_prices,
|
||||
)
|
||||
|
||||
|
||||
def test_calc_tp_from_fixed_rr_long():
|
||||
tp = calc_tp_from_fixed_rr("long", 100.0, 95.0, 1.5)
|
||||
assert tp == 107.5
|
||||
|
||||
|
||||
def test_calc_tp_from_fixed_rr_short():
|
||||
tp = calc_tp_from_fixed_rr("short", 100.0, 105.0, 1.5)
|
||||
assert tp == 92.5
|
||||
|
||||
|
||||
def test_resolve_open_fixed_rr_mode():
|
||||
sl, tp = resolve_open_sltp_prices(
|
||||
"long",
|
||||
100.0,
|
||||
"fixed_rr",
|
||||
{"sl": "95", "fixed_rr": "1.5"},
|
||||
)
|
||||
assert sl == 95.0
|
||||
assert tp == 107.5
|
||||
|
||||
|
||||
def test_parse_fixed_rr_default():
|
||||
assert parse_fixed_rr(None) == MANUAL_FIXED_RR_DEFAULT
|
||||
assert parse_fixed_rr("2") == 2.0
|
||||
from lib.trade.manual_sltp_lib import (
|
||||
MANUAL_FIXED_RR_DEFAULT,
|
||||
calc_tp_from_fixed_rr,
|
||||
parse_fixed_rr,
|
||||
resolve_open_sltp_prices,
|
||||
)
|
||||
|
||||
|
||||
def test_calc_tp_from_fixed_rr_long():
|
||||
tp = calc_tp_from_fixed_rr("long", 100.0, 95.0, 1.5)
|
||||
assert tp == 107.5
|
||||
|
||||
|
||||
def test_calc_tp_from_fixed_rr_short():
|
||||
tp = calc_tp_from_fixed_rr("short", 100.0, 105.0, 1.5)
|
||||
assert tp == 92.5
|
||||
|
||||
|
||||
def test_resolve_open_fixed_rr_mode():
|
||||
sl, tp = resolve_open_sltp_prices(
|
||||
"long",
|
||||
100.0,
|
||||
"fixed_rr",
|
||||
{"sl": "95", "fixed_rr": "1.5"},
|
||||
)
|
||||
assert sl == 95.0
|
||||
assert tp == 107.5
|
||||
|
||||
|
||||
def test_parse_fixed_rr_default():
|
||||
assert parse_fixed_rr(None) == MANUAL_FIXED_RR_DEFAULT
|
||||
assert parse_fixed_rr("2") == 2.0
|
||||
|
||||
@@ -1,102 +1,102 @@
|
||||
from order_monitor_display_lib import (
|
||||
apply_order_price_display_fields,
|
||||
is_sl_breakeven_secured,
|
||||
monitor_open_stop_loss,
|
||||
order_monitor_tpsl_needs_sync,
|
||||
resolve_live_tpsl_prices,
|
||||
sl_breakeven_from_exchange_tpsl,
|
||||
snapshot_rr,
|
||||
snapshot_stop_loss,
|
||||
)
|
||||
|
||||
|
||||
def _calc_rr(direction, entry, sl, tp):
|
||||
if direction == "long":
|
||||
risk = entry - sl
|
||||
reward = tp - entry
|
||||
else:
|
||||
risk = sl - entry
|
||||
reward = entry - tp
|
||||
if risk <= 0 or reward <= 0:
|
||||
return None
|
||||
return round(reward / risk, 4)
|
||||
|
||||
|
||||
def test_snapshot_stop_loss_prefers_initial():
|
||||
assert snapshot_stop_loss(2.45, 2.6) == 2.45
|
||||
assert snapshot_stop_loss(None, 2.6) == 2.6
|
||||
|
||||
|
||||
def test_monitor_open_stop_loss_prefers_initial_snapshot():
|
||||
row = {"initial_stop_loss": 64000, "stop_loss": 63200}
|
||||
assert monitor_open_stop_loss(row) == 64000
|
||||
|
||||
|
||||
def test_snapshot_rr_ignores_current_stop_after_manual_move():
|
||||
rr = snapshot_rr(_calc_rr, "long", 2.726, 2.45, 2.65, 3.3)
|
||||
assert rr is not None
|
||||
assert rr > 2.0
|
||||
|
||||
|
||||
def test_breakeven_long():
|
||||
assert is_sl_breakeven_secured("long", 2.726, 2.726) is True
|
||||
assert is_sl_breakeven_secured("long", 2.726, 2.75) is True
|
||||
assert is_sl_breakeven_secured("long", 2.726, 2.45) is False
|
||||
|
||||
|
||||
def test_breakeven_short():
|
||||
assert is_sl_breakeven_secured("short", 72.73, 72.73) is True
|
||||
assert is_sl_breakeven_secured("short", 72.73, 72.0) is True
|
||||
assert is_sl_breakeven_secured("short", 72.73, 74.0) is False
|
||||
|
||||
|
||||
def test_sl_breakeven_from_exchange_tpsl():
|
||||
ok = sl_breakeven_from_exchange_tpsl(
|
||||
"long",
|
||||
2.726,
|
||||
{"sl": {"trigger_price": 2.735}, "tp": {"trigger_price": 3.3}},
|
||||
)
|
||||
assert ok is True
|
||||
|
||||
|
||||
def test_resolve_live_tpsl_prefers_exchange():
|
||||
disp_sl, disp_tp, ex_sl, ex_tp = resolve_live_tpsl_prices(
|
||||
1674,
|
||||
1647.65,
|
||||
{"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}},
|
||||
)
|
||||
assert disp_sl == 1661
|
||||
assert disp_tp == 1647.65
|
||||
assert ex_sl == 1661
|
||||
assert ex_tp == 1647.65
|
||||
|
||||
|
||||
def test_order_monitor_tpsl_needs_sync_detects_sl_change():
|
||||
new_sl, new_tp, changed = order_monitor_tpsl_needs_sync(
|
||||
1674,
|
||||
1647.65,
|
||||
{"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}},
|
||||
)
|
||||
assert changed is True
|
||||
assert new_sl == 1661
|
||||
assert new_tp == 1647.65
|
||||
|
||||
|
||||
def test_apply_order_price_display_fields_live_sl():
|
||||
payload = {}
|
||||
apply_order_price_display_fields(
|
||||
payload,
|
||||
direction="short",
|
||||
entry_price=1663.45,
|
||||
initial_stop_loss=1674,
|
||||
stop_loss=1674,
|
||||
take_profit=1647.65,
|
||||
calc_rr_ratio_fn=_calc_rr,
|
||||
exchange_tpsl={"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}},
|
||||
format_price_fn=lambda _s, v: f"{v:.2f}",
|
||||
symbol="ETH/USDT:USDT",
|
||||
)
|
||||
assert payload["stop_loss"] == 1661
|
||||
assert payload["stop_loss_display"] == "1661.00"
|
||||
assert payload["sl_breakeven_secured"] is True
|
||||
assert payload["rr_ratio"] is not None
|
||||
from lib.trade.order_monitor_display_lib import (
|
||||
apply_order_price_display_fields,
|
||||
is_sl_breakeven_secured,
|
||||
monitor_open_stop_loss,
|
||||
order_monitor_tpsl_needs_sync,
|
||||
resolve_live_tpsl_prices,
|
||||
sl_breakeven_from_exchange_tpsl,
|
||||
snapshot_rr,
|
||||
snapshot_stop_loss,
|
||||
)
|
||||
|
||||
|
||||
def _calc_rr(direction, entry, sl, tp):
|
||||
if direction == "long":
|
||||
risk = entry - sl
|
||||
reward = tp - entry
|
||||
else:
|
||||
risk = sl - entry
|
||||
reward = entry - tp
|
||||
if risk <= 0 or reward <= 0:
|
||||
return None
|
||||
return round(reward / risk, 4)
|
||||
|
||||
|
||||
def test_snapshot_stop_loss_prefers_initial():
|
||||
assert snapshot_stop_loss(2.45, 2.6) == 2.45
|
||||
assert snapshot_stop_loss(None, 2.6) == 2.6
|
||||
|
||||
|
||||
def test_monitor_open_stop_loss_prefers_initial_snapshot():
|
||||
row = {"initial_stop_loss": 64000, "stop_loss": 63200}
|
||||
assert monitor_open_stop_loss(row) == 64000
|
||||
|
||||
|
||||
def test_snapshot_rr_ignores_current_stop_after_manual_move():
|
||||
rr = snapshot_rr(_calc_rr, "long", 2.726, 2.45, 2.65, 3.3)
|
||||
assert rr is not None
|
||||
assert rr > 2.0
|
||||
|
||||
|
||||
def test_breakeven_long():
|
||||
assert is_sl_breakeven_secured("long", 2.726, 2.726) is True
|
||||
assert is_sl_breakeven_secured("long", 2.726, 2.75) is True
|
||||
assert is_sl_breakeven_secured("long", 2.726, 2.45) is False
|
||||
|
||||
|
||||
def test_breakeven_short():
|
||||
assert is_sl_breakeven_secured("short", 72.73, 72.73) is True
|
||||
assert is_sl_breakeven_secured("short", 72.73, 72.0) is True
|
||||
assert is_sl_breakeven_secured("short", 72.73, 74.0) is False
|
||||
|
||||
|
||||
def test_sl_breakeven_from_exchange_tpsl():
|
||||
ok = sl_breakeven_from_exchange_tpsl(
|
||||
"long",
|
||||
2.726,
|
||||
{"sl": {"trigger_price": 2.735}, "tp": {"trigger_price": 3.3}},
|
||||
)
|
||||
assert ok is True
|
||||
|
||||
|
||||
def test_resolve_live_tpsl_prefers_exchange():
|
||||
disp_sl, disp_tp, ex_sl, ex_tp = resolve_live_tpsl_prices(
|
||||
1674,
|
||||
1647.65,
|
||||
{"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}},
|
||||
)
|
||||
assert disp_sl == 1661
|
||||
assert disp_tp == 1647.65
|
||||
assert ex_sl == 1661
|
||||
assert ex_tp == 1647.65
|
||||
|
||||
|
||||
def test_order_monitor_tpsl_needs_sync_detects_sl_change():
|
||||
new_sl, new_tp, changed = order_monitor_tpsl_needs_sync(
|
||||
1674,
|
||||
1647.65,
|
||||
{"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}},
|
||||
)
|
||||
assert changed is True
|
||||
assert new_sl == 1661
|
||||
assert new_tp == 1647.65
|
||||
|
||||
|
||||
def test_apply_order_price_display_fields_live_sl():
|
||||
payload = {}
|
||||
apply_order_price_display_fields(
|
||||
payload,
|
||||
direction="short",
|
||||
entry_price=1663.45,
|
||||
initial_stop_loss=1674,
|
||||
stop_loss=1674,
|
||||
take_profit=1647.65,
|
||||
calc_rr_ratio_fn=_calc_rr,
|
||||
exchange_tpsl={"sl": {"trigger_price": 1661}, "tp": {"trigger_price": 1647.65}},
|
||||
format_price_fn=lambda _s, v: f"{v:.2f}",
|
||||
symbol="ETH/USDT:USDT",
|
||||
)
|
||||
assert payload["stop_loss"] == 1661
|
||||
assert payload["stop_loss_display"] == "1661.00"
|
||||
assert payload["sl_breakeven_secured"] is True
|
||||
assert payload["rr_ratio"] is not None
|
||||
|
||||
@@ -1,78 +1,78 @@
|
||||
import sqlite3
|
||||
import unittest
|
||||
|
||||
from strategy_db import init_strategy_tables
|
||||
from strategy_trade_labels import (
|
||||
MONITOR_TYPE_TREND_PULLBACK,
|
||||
count_position_limit_active_monitors,
|
||||
)
|
||||
|
||||
|
||||
def _mem_conn():
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""CREATE TABLE order_monitors (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
symbol TEXT,
|
||||
direction TEXT,
|
||||
status TEXT,
|
||||
monitor_type TEXT,
|
||||
key_signal_type TEXT,
|
||||
trend_plan_id INTEGER
|
||||
)"""
|
||||
)
|
||||
init_strategy_tables(conn)
|
||||
return conn
|
||||
|
||||
|
||||
class PositionLimitCountTests(unittest.TestCase):
|
||||
def test_regular_monitor_counts(self):
|
||||
conn = _mem_conn()
|
||||
conn.execute(
|
||||
"INSERT INTO order_monitors (symbol, status, monitor_type) VALUES ('ETH/USDT', 'active', '下单监控')"
|
||||
)
|
||||
conn.commit()
|
||||
self.assertEqual(count_position_limit_active_monitors(conn), 1)
|
||||
|
||||
def test_trend_pullback_excluded(self):
|
||||
conn = _mem_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO order_monitors
|
||||
(symbol, status, monitor_type, trend_plan_id)
|
||||
VALUES ('ETH/USDT', 'active', ?, 12)""",
|
||||
(MONITOR_TYPE_TREND_PULLBACK,),
|
||||
)
|
||||
conn.commit()
|
||||
self.assertEqual(count_position_limit_active_monitors(conn), 0)
|
||||
|
||||
def test_active_roll_group_still_counts_regular_monitor(self):
|
||||
conn = _mem_conn()
|
||||
conn.execute(
|
||||
"INSERT INTO order_monitors (id, symbol, status, monitor_type) VALUES (1, 'ETH/USDT', 'active', '下单监控')"
|
||||
)
|
||||
conn.execute(
|
||||
"""INSERT INTO roll_groups
|
||||
(order_monitor_id, symbol, direction, status)
|
||||
VALUES (1, 'ETH/USDT', 'long', 'active')"""
|
||||
)
|
||||
conn.commit()
|
||||
self.assertEqual(count_position_limit_active_monitors(conn), 1)
|
||||
|
||||
def test_mixed_monitors(self):
|
||||
conn = _mem_conn()
|
||||
conn.execute(
|
||||
"INSERT INTO order_monitors (symbol, status, monitor_type) VALUES ('BTC/USDT', 'active', '下单监控')"
|
||||
)
|
||||
conn.execute(
|
||||
"""INSERT INTO order_monitors
|
||||
(symbol, status, monitor_type, trend_plan_id)
|
||||
VALUES ('ETH/USDT', 'active', ?, 3)""",
|
||||
(MONITOR_TYPE_TREND_PULLBACK,),
|
||||
)
|
||||
conn.commit()
|
||||
self.assertEqual(count_position_limit_active_monitors(conn), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
import sqlite3
|
||||
import unittest
|
||||
|
||||
from lib.strategy.strategy_db import init_strategy_tables
|
||||
from lib.strategy.strategy_trade_labels import (
|
||||
MONITOR_TYPE_TREND_PULLBACK,
|
||||
count_position_limit_active_monitors,
|
||||
)
|
||||
|
||||
|
||||
def _mem_conn():
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(
|
||||
"""CREATE TABLE order_monitors (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
symbol TEXT,
|
||||
direction TEXT,
|
||||
status TEXT,
|
||||
monitor_type TEXT,
|
||||
key_signal_type TEXT,
|
||||
trend_plan_id INTEGER
|
||||
)"""
|
||||
)
|
||||
init_strategy_tables(conn)
|
||||
return conn
|
||||
|
||||
|
||||
class PositionLimitCountTests(unittest.TestCase):
|
||||
def test_regular_monitor_counts(self):
|
||||
conn = _mem_conn()
|
||||
conn.execute(
|
||||
"INSERT INTO order_monitors (symbol, status, monitor_type) VALUES ('ETH/USDT', 'active', '下单监控')"
|
||||
)
|
||||
conn.commit()
|
||||
self.assertEqual(count_position_limit_active_monitors(conn), 1)
|
||||
|
||||
def test_trend_pullback_excluded(self):
|
||||
conn = _mem_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO order_monitors
|
||||
(symbol, status, monitor_type, trend_plan_id)
|
||||
VALUES ('ETH/USDT', 'active', ?, 12)""",
|
||||
(MONITOR_TYPE_TREND_PULLBACK,),
|
||||
)
|
||||
conn.commit()
|
||||
self.assertEqual(count_position_limit_active_monitors(conn), 0)
|
||||
|
||||
def test_active_roll_group_still_counts_regular_monitor(self):
|
||||
conn = _mem_conn()
|
||||
conn.execute(
|
||||
"INSERT INTO order_monitors (id, symbol, status, monitor_type) VALUES (1, 'ETH/USDT', 'active', '下单监控')"
|
||||
)
|
||||
conn.execute(
|
||||
"""INSERT INTO roll_groups
|
||||
(order_monitor_id, symbol, direction, status)
|
||||
VALUES (1, 'ETH/USDT', 'long', 'active')"""
|
||||
)
|
||||
conn.commit()
|
||||
self.assertEqual(count_position_limit_active_monitors(conn), 1)
|
||||
|
||||
def test_mixed_monitors(self):
|
||||
conn = _mem_conn()
|
||||
conn.execute(
|
||||
"INSERT INTO order_monitors (symbol, status, monitor_type) VALUES ('BTC/USDT', 'active', '下单监控')"
|
||||
)
|
||||
conn.execute(
|
||||
"""INSERT INTO order_monitors
|
||||
(symbol, status, monitor_type, trend_plan_id)
|
||||
VALUES ('ETH/USDT', 'active', ?, 3)""",
|
||||
(MONITOR_TYPE_TREND_PULLBACK,),
|
||||
)
|
||||
conn.commit()
|
||||
self.assertEqual(count_position_limit_active_monitors(conn), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,34 +1,34 @@
|
||||
"""全仓 / 以损定仓 风险展示文案。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from position_sizing_lib import ( # noqa: E402
|
||||
format_risk_display_text,
|
||||
risk_percent_for_storage,
|
||||
)
|
||||
|
||||
|
||||
class TestPositionSizingRiskDisplay(unittest.TestCase):
|
||||
def test_full_margin_shows_amount_only(self):
|
||||
self.assertEqual(
|
||||
format_risk_display_text("full_margin", 1.0, 2.58, decimals=2),
|
||||
"2.58U",
|
||||
)
|
||||
self.assertIsNone(risk_percent_for_storage("full_margin", 1.0))
|
||||
|
||||
def test_risk_mode_shows_percent_and_amount(self):
|
||||
self.assertEqual(
|
||||
format_risk_display_text("risk", 2.0, 10.5, decimals=2),
|
||||
"2%≈10.5U",
|
||||
)
|
||||
self.assertEqual(risk_percent_for_storage("risk", 2.0), 2.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""全仓 / 以损定仓 风险展示文案。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from lib.trade.position_sizing_lib import ( # noqa: E402
|
||||
format_risk_display_text,
|
||||
risk_percent_for_storage,
|
||||
)
|
||||
|
||||
|
||||
class TestPositionSizingRiskDisplay(unittest.TestCase):
|
||||
def test_full_margin_shows_amount_only(self):
|
||||
self.assertEqual(
|
||||
format_risk_display_text("full_margin", 1.0, 2.58, decimals=2),
|
||||
"2.58U",
|
||||
)
|
||||
self.assertIsNone(risk_percent_for_storage("full_margin", 1.0))
|
||||
|
||||
def test_risk_mode_shows_percent_and_amount(self):
|
||||
self.assertEqual(
|
||||
format_risk_display_text("risk", 2.0, 10.5, decimals=2),
|
||||
"2%≈10.5U",
|
||||
)
|
||||
self.assertEqual(risk_percent_for_storage("risk", 2.0), 2.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
+112
-112
@@ -1,112 +1,112 @@
|
||||
from strategy_roll_lib import (
|
||||
preview_roll,
|
||||
roll_breakout_invalidate,
|
||||
roll_breakout_trigger_crossed,
|
||||
roll_fib_invalidate,
|
||||
roll_fib_trigger_crossed,
|
||||
solve_add_amount_for_total_risk,
|
||||
validate_roll_geometry,
|
||||
)
|
||||
|
||||
|
||||
def test_solve_add_amount_long_one_risk():
|
||||
q2, err = solve_add_amount_for_total_risk(
|
||||
"long", 1.0, 3000.0, 3100.0, 2950.0, 200.0, 1.0
|
||||
)
|
||||
assert err is None
|
||||
avg = (1 * 3000 + q2 * 3100) / (1 + q2)
|
||||
loss = (avg - 2950) * (1 + q2)
|
||||
assert abs(loss - 200.0) < 0.01
|
||||
|
||||
|
||||
def test_preview_roll_market_short():
|
||||
preview, err = preview_roll(
|
||||
direction="short",
|
||||
symbol="HYPE/USDT",
|
||||
qty_existing=3.0,
|
||||
entry_existing=65.0,
|
||||
initial_take_profit=60.0,
|
||||
add_mode="market",
|
||||
new_stop_loss=66.5,
|
||||
risk_percent=2.0,
|
||||
capital_base_usdt=1000.0,
|
||||
add_price=64.0,
|
||||
legs_done=1,
|
||||
)
|
||||
assert err is None
|
||||
assert preview["add_mode_label"] == "市价加仓"
|
||||
sl = preview["new_stop_loss"]
|
||||
avg = preview["avg_entry_after"]
|
||||
qty = preview["qty_after"]
|
||||
loss = (sl - avg) * qty
|
||||
assert abs(loss - 20.0) < 0.01
|
||||
|
||||
|
||||
def test_fib_cross_long_down():
|
||||
assert roll_fib_trigger_crossed("long", 101.0, 100.0, 100.5) is True
|
||||
assert roll_fib_trigger_crossed("long", 100.6, 100.6, 100.5) is False
|
||||
|
||||
|
||||
def test_breakout_cross_long_up():
|
||||
assert roll_breakout_trigger_crossed("long", 99.0, 100.5, 100.0) is True
|
||||
assert roll_breakout_invalidate("long", 98.0, 99.0) is True
|
||||
assert roll_fib_invalidate("long", 110.0, 105.0, 95.0) is True
|
||||
|
||||
|
||||
def test_preview_breakout_mode_label():
|
||||
preview, err = preview_roll(
|
||||
direction="long",
|
||||
symbol="ETH/USDT",
|
||||
qty_existing=1.0,
|
||||
entry_existing=3000.0,
|
||||
initial_take_profit=3500.0,
|
||||
add_mode="breakout",
|
||||
new_stop_loss=2980.0,
|
||||
breakthrough_price=3100.0,
|
||||
risk_percent=10.0,
|
||||
capital_base_usdt=1000.0,
|
||||
add_price=3050.0,
|
||||
)
|
||||
assert err is None
|
||||
assert preview["add_mode_label"] == "突破加仓"
|
||||
|
||||
|
||||
def test_breakout_geometry_short_mark_above_breakout():
|
||||
err = validate_roll_geometry(
|
||||
"short",
|
||||
"breakout",
|
||||
new_stop_loss=568.0,
|
||||
breakthrough_price=551.0,
|
||||
entry_existing=560.0,
|
||||
initial_take_profit=540.0,
|
||||
mark_price=560.0,
|
||||
)
|
||||
assert err is None
|
||||
|
||||
|
||||
def test_breakout_geometry_short_rejects_mark_at_or_below_breakout():
|
||||
err = validate_roll_geometry(
|
||||
"short",
|
||||
"breakout",
|
||||
new_stop_loss=568.0,
|
||||
breakthrough_price=551.0,
|
||||
entry_existing=560.0,
|
||||
initial_take_profit=540.0,
|
||||
mark_price=551.0,
|
||||
)
|
||||
assert err is not None
|
||||
assert "高于突破价" in err
|
||||
|
||||
|
||||
def test_breakout_geometry_long_rejects_mark_at_or_above_breakout():
|
||||
err = validate_roll_geometry(
|
||||
"long",
|
||||
"breakout",
|
||||
new_stop_loss=2980.0,
|
||||
breakthrough_price=3100.0,
|
||||
entry_existing=3000.0,
|
||||
initial_take_profit=3500.0,
|
||||
mark_price=3100.0,
|
||||
)
|
||||
assert err is not None
|
||||
assert "低于突破价" in err
|
||||
from lib.strategy.strategy_roll_lib import (
|
||||
preview_roll,
|
||||
roll_breakout_invalidate,
|
||||
roll_breakout_trigger_crossed,
|
||||
roll_fib_invalidate,
|
||||
roll_fib_trigger_crossed,
|
||||
solve_add_amount_for_total_risk,
|
||||
validate_roll_geometry,
|
||||
)
|
||||
|
||||
|
||||
def test_solve_add_amount_long_one_risk():
|
||||
q2, err = solve_add_amount_for_total_risk(
|
||||
"long", 1.0, 3000.0, 3100.0, 2950.0, 200.0, 1.0
|
||||
)
|
||||
assert err is None
|
||||
avg = (1 * 3000 + q2 * 3100) / (1 + q2)
|
||||
loss = (avg - 2950) * (1 + q2)
|
||||
assert abs(loss - 200.0) < 0.01
|
||||
|
||||
|
||||
def test_preview_roll_market_short():
|
||||
preview, err = preview_roll(
|
||||
direction="short",
|
||||
symbol="HYPE/USDT",
|
||||
qty_existing=3.0,
|
||||
entry_existing=65.0,
|
||||
initial_take_profit=60.0,
|
||||
add_mode="market",
|
||||
new_stop_loss=66.5,
|
||||
risk_percent=2.0,
|
||||
capital_base_usdt=1000.0,
|
||||
add_price=64.0,
|
||||
legs_done=1,
|
||||
)
|
||||
assert err is None
|
||||
assert preview["add_mode_label"] == "市价加仓"
|
||||
sl = preview["new_stop_loss"]
|
||||
avg = preview["avg_entry_after"]
|
||||
qty = preview["qty_after"]
|
||||
loss = (sl - avg) * qty
|
||||
assert abs(loss - 20.0) < 0.01
|
||||
|
||||
|
||||
def test_fib_cross_long_down():
|
||||
assert roll_fib_trigger_crossed("long", 101.0, 100.0, 100.5) is True
|
||||
assert roll_fib_trigger_crossed("long", 100.6, 100.6, 100.5) is False
|
||||
|
||||
|
||||
def test_breakout_cross_long_up():
|
||||
assert roll_breakout_trigger_crossed("long", 99.0, 100.5, 100.0) is True
|
||||
assert roll_breakout_invalidate("long", 98.0, 99.0) is True
|
||||
assert roll_fib_invalidate("long", 110.0, 105.0, 95.0) is True
|
||||
|
||||
|
||||
def test_preview_breakout_mode_label():
|
||||
preview, err = preview_roll(
|
||||
direction="long",
|
||||
symbol="ETH/USDT",
|
||||
qty_existing=1.0,
|
||||
entry_existing=3000.0,
|
||||
initial_take_profit=3500.0,
|
||||
add_mode="breakout",
|
||||
new_stop_loss=2980.0,
|
||||
breakthrough_price=3100.0,
|
||||
risk_percent=10.0,
|
||||
capital_base_usdt=1000.0,
|
||||
add_price=3050.0,
|
||||
)
|
||||
assert err is None
|
||||
assert preview["add_mode_label"] == "突破加仓"
|
||||
|
||||
|
||||
def test_breakout_geometry_short_mark_above_breakout():
|
||||
err = validate_roll_geometry(
|
||||
"short",
|
||||
"breakout",
|
||||
new_stop_loss=568.0,
|
||||
breakthrough_price=551.0,
|
||||
entry_existing=560.0,
|
||||
initial_take_profit=540.0,
|
||||
mark_price=560.0,
|
||||
)
|
||||
assert err is None
|
||||
|
||||
|
||||
def test_breakout_geometry_short_rejects_mark_at_or_below_breakout():
|
||||
err = validate_roll_geometry(
|
||||
"short",
|
||||
"breakout",
|
||||
new_stop_loss=568.0,
|
||||
breakthrough_price=551.0,
|
||||
entry_existing=560.0,
|
||||
initial_take_profit=540.0,
|
||||
mark_price=551.0,
|
||||
)
|
||||
assert err is not None
|
||||
assert "高于突破价" in err
|
||||
|
||||
|
||||
def test_breakout_geometry_long_rejects_mark_at_or_above_breakout():
|
||||
err = validate_roll_geometry(
|
||||
"long",
|
||||
"breakout",
|
||||
new_stop_loss=2980.0,
|
||||
breakthrough_price=3100.0,
|
||||
entry_existing=3000.0,
|
||||
initial_take_profit=3500.0,
|
||||
mark_price=3100.0,
|
||||
)
|
||||
assert err is not None
|
||||
assert "低于突破价" in err
|
||||
|
||||
@@ -1,44 +1,44 @@
|
||||
"""strategy_roll_ui_lib 单元测试。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
import strategy_roll_ui_lib as roll_ui
|
||||
|
||||
|
||||
def test_compute_roll_chain_metrics_short():
|
||||
group = {
|
||||
"id": 1,
|
||||
"direction": "short",
|
||||
"initial_take_profit": 60.0,
|
||||
}
|
||||
legs = [
|
||||
{"id": 10, "leg_index": 1, "amount": 3.0, "fill_price": 65.0, "status": "filled"},
|
||||
{"id": 11, "leg_index": 2, "amount": 5.0, "fill_price": 64.0, "status": "filled"},
|
||||
]
|
||||
per_leg, group_metrics = roll_ui.compute_roll_chain_metrics(
|
||||
group,
|
||||
legs,
|
||||
qty_live=8.0,
|
||||
entry_live=63.5,
|
||||
monitor={"trigger_price": 66.0, "order_amount": 3.0},
|
||||
)
|
||||
assert per_leg[10]["avg_entry_after"] is not None
|
||||
assert per_leg[11]["avg_entry_after"] is not None
|
||||
assert group_metrics["reward_at_tp_usdt"] is not None
|
||||
assert per_leg[11]["reward_at_tp_usdt"] >= per_leg[10]["reward_at_tp_usdt"]
|
||||
|
||||
|
||||
def test_infer_initial_position_from_live():
|
||||
legs = [{"amount": 2.0, "fill_price": 64.0, "status": "filled"}]
|
||||
q0, e0 = roll_ui.infer_initial_position(5.0, 63.0, legs)
|
||||
assert q0 == 3.0
|
||||
assert abs(e0 - 62.3333333333) < 0.001
|
||||
|
||||
|
||||
def test_reward_at_tp_long():
|
||||
assert roll_ui.reward_at_tp_usdt("long", 100.0, 110.0, 2.0) == 20.0
|
||||
"""strategy_roll_ui_lib 单元测试。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
import lib.strategy.strategy_roll_ui_lib as roll_ui
|
||||
|
||||
|
||||
def test_compute_roll_chain_metrics_short():
|
||||
group = {
|
||||
"id": 1,
|
||||
"direction": "short",
|
||||
"initial_take_profit": 60.0,
|
||||
}
|
||||
legs = [
|
||||
{"id": 10, "leg_index": 1, "amount": 3.0, "fill_price": 65.0, "status": "filled"},
|
||||
{"id": 11, "leg_index": 2, "amount": 5.0, "fill_price": 64.0, "status": "filled"},
|
||||
]
|
||||
per_leg, group_metrics = roll_ui.compute_roll_chain_metrics(
|
||||
group,
|
||||
legs,
|
||||
qty_live=8.0,
|
||||
entry_live=63.5,
|
||||
monitor={"trigger_price": 66.0, "order_amount": 3.0},
|
||||
)
|
||||
assert per_leg[10]["avg_entry_after"] is not None
|
||||
assert per_leg[11]["avg_entry_after"] is not None
|
||||
assert group_metrics["reward_at_tp_usdt"] is not None
|
||||
assert per_leg[11]["reward_at_tp_usdt"] >= per_leg[10]["reward_at_tp_usdt"]
|
||||
|
||||
|
||||
def test_infer_initial_position_from_live():
|
||||
legs = [{"amount": 2.0, "fill_price": 64.0, "status": "filled"}]
|
||||
q0, e0 = roll_ui.infer_initial_position(5.0, 63.0, legs)
|
||||
assert q0 == 3.0
|
||||
assert abs(e0 - 62.3333333333) < 0.001
|
||||
|
||||
|
||||
def test_reward_at_tp_long():
|
||||
assert roll_ui.reward_at_tp_usdt("long", 100.0, 110.0, 2.0) == 20.0
|
||||
|
||||
@@ -1,183 +1,183 @@
|
||||
"""策略快照:同一计划同结果不重复写入。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from strategy_snapshot_lib import ( # noqa: E402
|
||||
STRATEGY_TREND,
|
||||
dedupe_strategy_snapshots,
|
||||
init_strategy_snapshot_table,
|
||||
list_strategy_snapshots,
|
||||
save_trend_plan_snapshot,
|
||||
)
|
||||
|
||||
|
||||
def _mem_conn() -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
init_strategy_snapshot_table(conn)
|
||||
return conn
|
||||
|
||||
|
||||
def test_save_trend_plan_snapshot_skips_duplicate_result():
|
||||
conn = _mem_conn()
|
||||
plan = {
|
||||
"id": 42,
|
||||
"symbol": "ONDO/USDT",
|
||||
"exchange_symbol": "ONDO/USDT:USDT",
|
||||
"direction": "short",
|
||||
"status": "active",
|
||||
"opened_at": "2026-06-08 08:00:00",
|
||||
"legs_done": 4,
|
||||
"dca_legs": 4,
|
||||
"first_order_done": 1,
|
||||
"grid_prices_json": "[]",
|
||||
"leg_amounts_json": "[]",
|
||||
}
|
||||
cfg = {"app_module": type("M", (), {"app_now_str": staticmethod(lambda: "2026-06-08 08:41:00")})()}
|
||||
save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.3)
|
||||
save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.4)
|
||||
conn.commit()
|
||||
rows = conn.execute(
|
||||
"SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=? AND result_label=?",
|
||||
(42, "止损"),
|
||||
).fetchone()
|
||||
assert int(rows["c"]) == 1
|
||||
|
||||
|
||||
def test_dedupe_strategy_snapshots_handles_many_duplicates():
|
||||
conn = _mem_conn()
|
||||
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
|
||||
for snap_id in range(1, 46):
|
||||
conn.execute(
|
||||
"""INSERT INTO strategy_trade_snapshots (
|
||||
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
|
||||
) VALUES (?,?,?,?,?,?,?,?,?)""",
|
||||
(
|
||||
snap_id,
|
||||
STRATEGY_TREND,
|
||||
99,
|
||||
"ONDO/USDT",
|
||||
"止损",
|
||||
payload,
|
||||
"2026-06-08 08:41:00",
|
||||
"2026-06-08 08:41:00",
|
||||
-2.2,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
removed = dedupe_strategy_snapshots(conn)
|
||||
conn.commit()
|
||||
assert removed == 44
|
||||
row = conn.execute(
|
||||
"SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=?",
|
||||
(99,),
|
||||
).fetchone()
|
||||
assert int(row["c"]) == 1
|
||||
|
||||
|
||||
def test_dedupe_strategy_snapshots_keeps_latest_id():
|
||||
conn = _mem_conn()
|
||||
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
|
||||
for snap_id, pnl in ((1, -2.23), (2, -2.31), (3, -2.38)):
|
||||
conn.execute(
|
||||
"""INSERT INTO strategy_trade_snapshots (
|
||||
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
|
||||
) VALUES (?,?,?,?,?,?,?,?,?)""",
|
||||
(
|
||||
snap_id,
|
||||
STRATEGY_TREND,
|
||||
5,
|
||||
"ONDO/USDT",
|
||||
"止损",
|
||||
payload,
|
||||
"2026-06-08 08:41:00",
|
||||
"2026-06-08 08:41:00",
|
||||
pnl,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
removed = dedupe_strategy_snapshots(conn)
|
||||
conn.commit()
|
||||
assert removed == 2
|
||||
row = conn.execute(
|
||||
"SELECT id, pnl_amount FROM strategy_trade_snapshots WHERE source_id=?",
|
||||
(5,),
|
||||
).fetchone()
|
||||
assert int(row["id"]) == 3
|
||||
assert abs(float(row["pnl_amount"]) - (-2.38)) < 1e-6
|
||||
|
||||
|
||||
def test_list_strategy_snapshots_hides_duplicate_keys():
|
||||
conn = _mem_conn()
|
||||
payload = json.dumps({"symbol": "ONDO/USDT", "dca_levels": []}, ensure_ascii=False)
|
||||
for snap_id in (10, 11, 12):
|
||||
conn.execute(
|
||||
"""INSERT INTO strategy_trade_snapshots (
|
||||
id, strategy_type, source_id, symbol, direction, result_label,
|
||||
snapshot_json, closed_at, created_at, pnl_amount
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?)""",
|
||||
(
|
||||
snap_id,
|
||||
STRATEGY_TREND,
|
||||
7,
|
||||
"ONDO/USDT",
|
||||
"short",
|
||||
"止损",
|
||||
payload,
|
||||
"2026-06-08 08:41:00",
|
||||
"2026-06-08 08:41:00",
|
||||
-2.2,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
rows = list_strategy_snapshots(conn, limit=50)
|
||||
stop_rows = [r for r in rows if int(r.get("source_id") or 0) == 7]
|
||||
assert len(stop_rows) == 1
|
||||
assert int(stop_rows[0]["id"]) == 12
|
||||
|
||||
|
||||
def test_dedupe_keeps_manual_over_stop_loss():
|
||||
conn = _mem_conn()
|
||||
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
|
||||
for snap_id, label in ((10, "止损"), (11, "手动平仓")):
|
||||
conn.execute(
|
||||
"""INSERT INTO strategy_trade_snapshots (
|
||||
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
|
||||
) VALUES (?,?,?,?,?,?,?,?,?)""",
|
||||
(
|
||||
snap_id,
|
||||
STRATEGY_TREND,
|
||||
7,
|
||||
"ONDO/USDT",
|
||||
label,
|
||||
payload,
|
||||
"2026-06-08 08:44:00",
|
||||
"2026-06-08 08:44:00",
|
||||
-2.23,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
removed = dedupe_strategy_snapshots(conn)
|
||||
conn.commit()
|
||||
assert removed == 1
|
||||
row = conn.execute(
|
||||
"SELECT result_label FROM strategy_trade_snapshots WHERE source_id=?",
|
||||
(7,),
|
||||
).fetchone()
|
||||
assert row["result_label"] == "手动平仓"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_save_trend_plan_snapshot_skips_duplicate_result()
|
||||
test_dedupe_strategy_snapshots_handles_many_duplicates()
|
||||
test_dedupe_strategy_snapshots_keeps_latest_id()
|
||||
test_list_strategy_snapshots_hides_duplicate_keys()
|
||||
test_dedupe_keeps_manual_over_stop_loss()
|
||||
print("all ok")
|
||||
"""策略快照:同一计划同结果不重复写入。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from lib.strategy.strategy_snapshot_lib import ( # noqa: E402
|
||||
STRATEGY_TREND,
|
||||
dedupe_strategy_snapshots,
|
||||
init_strategy_snapshot_table,
|
||||
list_strategy_snapshots,
|
||||
save_trend_plan_snapshot,
|
||||
)
|
||||
|
||||
|
||||
def _mem_conn() -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
init_strategy_snapshot_table(conn)
|
||||
return conn
|
||||
|
||||
|
||||
def test_save_trend_plan_snapshot_skips_duplicate_result():
|
||||
conn = _mem_conn()
|
||||
plan = {
|
||||
"id": 42,
|
||||
"symbol": "ONDO/USDT",
|
||||
"exchange_symbol": "ONDO/USDT:USDT",
|
||||
"direction": "short",
|
||||
"status": "active",
|
||||
"opened_at": "2026-06-08 08:00:00",
|
||||
"legs_done": 4,
|
||||
"dca_legs": 4,
|
||||
"first_order_done": 1,
|
||||
"grid_prices_json": "[]",
|
||||
"leg_amounts_json": "[]",
|
||||
}
|
||||
cfg = {"app_module": type("M", (), {"app_now_str": staticmethod(lambda: "2026-06-08 08:41:00")})()}
|
||||
save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.3)
|
||||
save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.4)
|
||||
conn.commit()
|
||||
rows = conn.execute(
|
||||
"SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=? AND result_label=?",
|
||||
(42, "止损"),
|
||||
).fetchone()
|
||||
assert int(rows["c"]) == 1
|
||||
|
||||
|
||||
def test_dedupe_strategy_snapshots_handles_many_duplicates():
|
||||
conn = _mem_conn()
|
||||
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
|
||||
for snap_id in range(1, 46):
|
||||
conn.execute(
|
||||
"""INSERT INTO strategy_trade_snapshots (
|
||||
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
|
||||
) VALUES (?,?,?,?,?,?,?,?,?)""",
|
||||
(
|
||||
snap_id,
|
||||
STRATEGY_TREND,
|
||||
99,
|
||||
"ONDO/USDT",
|
||||
"止损",
|
||||
payload,
|
||||
"2026-06-08 08:41:00",
|
||||
"2026-06-08 08:41:00",
|
||||
-2.2,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
removed = dedupe_strategy_snapshots(conn)
|
||||
conn.commit()
|
||||
assert removed == 44
|
||||
row = conn.execute(
|
||||
"SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=?",
|
||||
(99,),
|
||||
).fetchone()
|
||||
assert int(row["c"]) == 1
|
||||
|
||||
|
||||
def test_dedupe_strategy_snapshots_keeps_latest_id():
|
||||
conn = _mem_conn()
|
||||
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
|
||||
for snap_id, pnl in ((1, -2.23), (2, -2.31), (3, -2.38)):
|
||||
conn.execute(
|
||||
"""INSERT INTO strategy_trade_snapshots (
|
||||
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
|
||||
) VALUES (?,?,?,?,?,?,?,?,?)""",
|
||||
(
|
||||
snap_id,
|
||||
STRATEGY_TREND,
|
||||
5,
|
||||
"ONDO/USDT",
|
||||
"止损",
|
||||
payload,
|
||||
"2026-06-08 08:41:00",
|
||||
"2026-06-08 08:41:00",
|
||||
pnl,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
removed = dedupe_strategy_snapshots(conn)
|
||||
conn.commit()
|
||||
assert removed == 2
|
||||
row = conn.execute(
|
||||
"SELECT id, pnl_amount FROM strategy_trade_snapshots WHERE source_id=?",
|
||||
(5,),
|
||||
).fetchone()
|
||||
assert int(row["id"]) == 3
|
||||
assert abs(float(row["pnl_amount"]) - (-2.38)) < 1e-6
|
||||
|
||||
|
||||
def test_list_strategy_snapshots_hides_duplicate_keys():
|
||||
conn = _mem_conn()
|
||||
payload = json.dumps({"symbol": "ONDO/USDT", "dca_levels": []}, ensure_ascii=False)
|
||||
for snap_id in (10, 11, 12):
|
||||
conn.execute(
|
||||
"""INSERT INTO strategy_trade_snapshots (
|
||||
id, strategy_type, source_id, symbol, direction, result_label,
|
||||
snapshot_json, closed_at, created_at, pnl_amount
|
||||
) VALUES (?,?,?,?,?,?,?,?,?,?)""",
|
||||
(
|
||||
snap_id,
|
||||
STRATEGY_TREND,
|
||||
7,
|
||||
"ONDO/USDT",
|
||||
"short",
|
||||
"止损",
|
||||
payload,
|
||||
"2026-06-08 08:41:00",
|
||||
"2026-06-08 08:41:00",
|
||||
-2.2,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
rows = list_strategy_snapshots(conn, limit=50)
|
||||
stop_rows = [r for r in rows if int(r.get("source_id") or 0) == 7]
|
||||
assert len(stop_rows) == 1
|
||||
assert int(stop_rows[0]["id"]) == 12
|
||||
|
||||
|
||||
def test_dedupe_keeps_manual_over_stop_loss():
|
||||
conn = _mem_conn()
|
||||
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
|
||||
for snap_id, label in ((10, "止损"), (11, "手动平仓")):
|
||||
conn.execute(
|
||||
"""INSERT INTO strategy_trade_snapshots (
|
||||
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
|
||||
) VALUES (?,?,?,?,?,?,?,?,?)""",
|
||||
(
|
||||
snap_id,
|
||||
STRATEGY_TREND,
|
||||
7,
|
||||
"ONDO/USDT",
|
||||
label,
|
||||
payload,
|
||||
"2026-06-08 08:44:00",
|
||||
"2026-06-08 08:44:00",
|
||||
-2.23,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
removed = dedupe_strategy_snapshots(conn)
|
||||
conn.commit()
|
||||
assert removed == 1
|
||||
row = conn.execute(
|
||||
"SELECT result_label FROM strategy_trade_snapshots WHERE source_id=?",
|
||||
(7,),
|
||||
).fetchone()
|
||||
assert row["result_label"] == "手动平仓"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_save_trend_plan_snapshot_skips_duplicate_result()
|
||||
test_dedupe_strategy_snapshots_handles_many_duplicates()
|
||||
test_dedupe_strategy_snapshots_keeps_latest_id()
|
||||
test_list_strategy_snapshots_hides_duplicate_keys()
|
||||
test_dedupe_keeps_manual_over_stop_loss()
|
||||
print("all ok")
|
||||
|
||||
@@ -1,48 +1,48 @@
|
||||
import unittest
|
||||
|
||||
from trade_exchange_stats_lib import (
|
||||
aggregate_bilateral_stats,
|
||||
commission_usdt_from_fill,
|
||||
filter_position_lifecycle_fills,
|
||||
merge_commission_prefer_income,
|
||||
quote_turnover_usdt_from_fill,
|
||||
)
|
||||
|
||||
|
||||
class TradeExchangeStatsTests(unittest.TestCase):
|
||||
def test_turnover_from_cost(self):
|
||||
t = {"cost": 1000.0, "price": 50, "amount": 20}
|
||||
self.assertEqual(quote_turnover_usdt_from_fill(t), 1000.0)
|
||||
|
||||
def test_commission_from_fee(self):
|
||||
t = {"fee": {"cost": -0.42, "currency": "USDT"}}
|
||||
self.assertEqual(commission_usdt_from_fill(t), 0.42)
|
||||
|
||||
def test_bilateral_aggregate(self):
|
||||
fills = [
|
||||
{"side": "buy", "cost": 500, "fee": {"cost": -0.2, "currency": "USDT"}, "timestamp": 1000},
|
||||
{"side": "sell", "cost": 520, "fee": {"cost": -0.21, "currency": "USDT"}, "timestamp": 2000},
|
||||
]
|
||||
stats = aggregate_bilateral_stats(fills)
|
||||
self.assertIsNotNone(stats)
|
||||
self.assertEqual(stats["exchange_turnover_usdt"], 1020.0)
|
||||
self.assertEqual(stats["exchange_commission_usdt"], 0.41)
|
||||
|
||||
def test_filter_long_lifecycle(self):
|
||||
base = 1_700_000_000_000
|
||||
trades = [
|
||||
{"side": "buy", "timestamp": base, "cost": 100},
|
||||
{"side": "sell", "timestamp": base + 60_000, "cost": 110},
|
||||
{"side": "buy", "timestamp": base + 120_000, "cost": 999},
|
||||
]
|
||||
got = filter_position_lifecycle_fills(
|
||||
trades, "long", base - 1000, base + 90_000, close_buffer_ms=0
|
||||
)
|
||||
self.assertEqual(len(got), 2)
|
||||
|
||||
def test_prefer_income_commission(self):
|
||||
self.assertEqual(merge_commission_prefer_income(0.3, 0.45), 0.45)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
import unittest
|
||||
|
||||
from lib.trade.trade_exchange_stats_lib import (
|
||||
aggregate_bilateral_stats,
|
||||
commission_usdt_from_fill,
|
||||
filter_position_lifecycle_fills,
|
||||
merge_commission_prefer_income,
|
||||
quote_turnover_usdt_from_fill,
|
||||
)
|
||||
|
||||
|
||||
class TradeExchangeStatsTests(unittest.TestCase):
|
||||
def test_turnover_from_cost(self):
|
||||
t = {"cost": 1000.0, "price": 50, "amount": 20}
|
||||
self.assertEqual(quote_turnover_usdt_from_fill(t), 1000.0)
|
||||
|
||||
def test_commission_from_fee(self):
|
||||
t = {"fee": {"cost": -0.42, "currency": "USDT"}}
|
||||
self.assertEqual(commission_usdt_from_fill(t), 0.42)
|
||||
|
||||
def test_bilateral_aggregate(self):
|
||||
fills = [
|
||||
{"side": "buy", "cost": 500, "fee": {"cost": -0.2, "currency": "USDT"}, "timestamp": 1000},
|
||||
{"side": "sell", "cost": 520, "fee": {"cost": -0.21, "currency": "USDT"}, "timestamp": 2000},
|
||||
]
|
||||
stats = aggregate_bilateral_stats(fills)
|
||||
self.assertIsNotNone(stats)
|
||||
self.assertEqual(stats["exchange_turnover_usdt"], 1020.0)
|
||||
self.assertEqual(stats["exchange_commission_usdt"], 0.41)
|
||||
|
||||
def test_filter_long_lifecycle(self):
|
||||
base = 1_700_000_000_000
|
||||
trades = [
|
||||
{"side": "buy", "timestamp": base, "cost": 100},
|
||||
{"side": "sell", "timestamp": base + 60_000, "cost": 110},
|
||||
{"side": "buy", "timestamp": base + 120_000, "cost": 999},
|
||||
]
|
||||
got = filter_position_lifecycle_fills(
|
||||
trades, "long", base - 1000, base + 90_000, close_buffer_ms=0
|
||||
)
|
||||
self.assertEqual(len(got), 2)
|
||||
|
||||
def test_prefer_income_commission(self):
|
||||
self.assertEqual(merge_commission_prefer_income(0.3, 0.45), 0.45)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,30 +1,30 @@
|
||||
from trade_result_lib import normalize_result_with_pnl, normalize_display_result, is_winning_pnl
|
||||
|
||||
|
||||
def test_stop_loss_with_profit_becomes_trailing_tp():
|
||||
assert normalize_result_with_pnl("止损", 4.33) == "移动止盈"
|
||||
|
||||
|
||||
def test_manual_close_unchanged_even_with_profit():
|
||||
assert normalize_result_with_pnl("手动平仓", 10) == "手动平仓"
|
||||
|
||||
|
||||
def test_stop_loss_with_loss_unchanged():
|
||||
assert normalize_result_with_pnl("止损", -2.5) == "止损"
|
||||
|
||||
|
||||
def test_take_profit_unchanged():
|
||||
assert normalize_result_with_pnl("止盈", 5) == "止盈"
|
||||
|
||||
|
||||
def test_external_close_becomes_manual_close():
|
||||
assert normalize_display_result("外部平仓") == "手动平仓"
|
||||
assert normalize_result_with_pnl("外部平仓", 2.5) == "手动平仓"
|
||||
assert normalize_result_with_pnl("外部平仓(自动同步)", -1) == "手动平仓"
|
||||
|
||||
|
||||
def test_winning_pnl_positive_only():
|
||||
assert is_winning_pnl(2.96) is True
|
||||
assert is_winning_pnl(0) is False
|
||||
assert is_winning_pnl(-1.05) is False
|
||||
assert is_winning_pnl(None) is False
|
||||
from lib.trade.trade_result_lib import normalize_result_with_pnl, normalize_display_result, is_winning_pnl
|
||||
|
||||
|
||||
def test_stop_loss_with_profit_becomes_trailing_tp():
|
||||
assert normalize_result_with_pnl("止损", 4.33) == "移动止盈"
|
||||
|
||||
|
||||
def test_manual_close_unchanged_even_with_profit():
|
||||
assert normalize_result_with_pnl("手动平仓", 10) == "手动平仓"
|
||||
|
||||
|
||||
def test_stop_loss_with_loss_unchanged():
|
||||
assert normalize_result_with_pnl("止损", -2.5) == "止损"
|
||||
|
||||
|
||||
def test_take_profit_unchanged():
|
||||
assert normalize_result_with_pnl("止盈", 5) == "止盈"
|
||||
|
||||
|
||||
def test_external_close_becomes_manual_close():
|
||||
assert normalize_display_result("外部平仓") == "手动平仓"
|
||||
assert normalize_result_with_pnl("外部平仓", 2.5) == "手动平仓"
|
||||
assert normalize_result_with_pnl("外部平仓(自动同步)", -1) == "手动平仓"
|
||||
|
||||
|
||||
def test_winning_pnl_positive_only():
|
||||
assert is_winning_pnl(2.96) is True
|
||||
assert is_winning_pnl(0) is False
|
||||
assert is_winning_pnl(-1.05) is False
|
||||
assert is_winning_pnl(None) is False
|
||||
|
||||
@@ -1,90 +1,90 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from trade_stats_calendar_lib import (
|
||||
build_initial_stats_calendar,
|
||||
build_stats_calendar_bootstrap,
|
||||
build_trade_stats_calendar,
|
||||
)
|
||||
|
||||
|
||||
def _row(**kwargs):
|
||||
base = {
|
||||
"monitor_type": "",
|
||||
"key_signal_type": "",
|
||||
"exchange_turnover_usdt": None,
|
||||
"exchange_commission_usdt": None,
|
||||
}
|
||||
base.update(kwargs)
|
||||
return SimpleNamespace(**base)
|
||||
|
||||
|
||||
def _matches_all(row, segment_key):
|
||||
return segment_key == "all"
|
||||
|
||||
|
||||
def _matches_manual(row, segment_key):
|
||||
if segment_key == "all":
|
||||
return True
|
||||
if segment_key == "manual":
|
||||
return (row.monitor_type or "").strip() == "手动" and not (row.key_signal_type or "").strip()
|
||||
return False
|
||||
|
||||
|
||||
class TradeStatsCalendarLibTests(unittest.TestCase):
|
||||
def test_groups_by_trading_day_and_segment(self):
|
||||
pnls = [
|
||||
(10.0, None, "2026-06-18", _row(monitor_type="手动")),
|
||||
(-3.0, None, "2026-06-18", _row(monitor_type="手动")),
|
||||
(5.0, None, "2026-06-19", _row(monitor_type="自动", key_signal_type="箱体突破")),
|
||||
]
|
||||
payload = build_trade_stats_calendar(
|
||||
pnls,
|
||||
2026,
|
||||
6,
|
||||
"manual",
|
||||
_matches_manual,
|
||||
reset_hour=8,
|
||||
)
|
||||
self.assertEqual(payload["month"], 6)
|
||||
self.assertEqual(payload["month_open_count"], 2)
|
||||
days = payload["days"]
|
||||
self.assertIn("2026-06-18", days)
|
||||
self.assertNotIn("2026-06-19", days)
|
||||
self.assertEqual(days["2026-06-18"]["open_count"], 2)
|
||||
self.assertAlmostEqual(days["2026-06-18"]["pnl_total"], 7.0)
|
||||
|
||||
def test_invalid_month_raises(self):
|
||||
with self.assertRaises(ValueError):
|
||||
build_trade_stats_calendar([], 2026, 13, "all", _matches_all)
|
||||
|
||||
def test_initial_calendar_uses_current_month(self):
|
||||
pnls = [(2.5, None, "2026-06-20", _row())]
|
||||
payload = build_initial_stats_calendar(
|
||||
pnls,
|
||||
datetime(2026, 6, 26, 12, 0),
|
||||
_matches_all,
|
||||
reset_hour=8,
|
||||
)
|
||||
self.assertEqual(payload["year"], 2026)
|
||||
self.assertEqual(payload["month"], 6)
|
||||
self.assertEqual(payload["month_open_count"], 1)
|
||||
self.assertIn("2026-06-20", payload["days"])
|
||||
|
||||
def test_bootstrap_json_roundtrip(self):
|
||||
pnls = [(2.5, None, "2026-06-20", _row())]
|
||||
payload, raw = build_stats_calendar_bootstrap(
|
||||
pnls,
|
||||
datetime(2026, 6, 26, 12, 0),
|
||||
_matches_all,
|
||||
reset_hour=8,
|
||||
)
|
||||
self.assertIsNotNone(payload)
|
||||
self.assertIsNotNone(raw)
|
||||
self.assertIn('"month_open_count":1', raw.replace(" ", ""))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from lib.trade.trade_stats_calendar_lib import (
|
||||
build_initial_stats_calendar,
|
||||
build_stats_calendar_bootstrap,
|
||||
build_trade_stats_calendar,
|
||||
)
|
||||
|
||||
|
||||
def _row(**kwargs):
|
||||
base = {
|
||||
"monitor_type": "",
|
||||
"key_signal_type": "",
|
||||
"exchange_turnover_usdt": None,
|
||||
"exchange_commission_usdt": None,
|
||||
}
|
||||
base.update(kwargs)
|
||||
return SimpleNamespace(**base)
|
||||
|
||||
|
||||
def _matches_all(row, segment_key):
|
||||
return segment_key == "all"
|
||||
|
||||
|
||||
def _matches_manual(row, segment_key):
|
||||
if segment_key == "all":
|
||||
return True
|
||||
if segment_key == "manual":
|
||||
return (row.monitor_type or "").strip() == "手动" and not (row.key_signal_type or "").strip()
|
||||
return False
|
||||
|
||||
|
||||
class TradeStatsCalendarLibTests(unittest.TestCase):
|
||||
def test_groups_by_trading_day_and_segment(self):
|
||||
pnls = [
|
||||
(10.0, None, "2026-06-18", _row(monitor_type="手动")),
|
||||
(-3.0, None, "2026-06-18", _row(monitor_type="手动")),
|
||||
(5.0, None, "2026-06-19", _row(monitor_type="自动", key_signal_type="箱体突破")),
|
||||
]
|
||||
payload = build_trade_stats_calendar(
|
||||
pnls,
|
||||
2026,
|
||||
6,
|
||||
"manual",
|
||||
_matches_manual,
|
||||
reset_hour=8,
|
||||
)
|
||||
self.assertEqual(payload["month"], 6)
|
||||
self.assertEqual(payload["month_open_count"], 2)
|
||||
days = payload["days"]
|
||||
self.assertIn("2026-06-18", days)
|
||||
self.assertNotIn("2026-06-19", days)
|
||||
self.assertEqual(days["2026-06-18"]["open_count"], 2)
|
||||
self.assertAlmostEqual(days["2026-06-18"]["pnl_total"], 7.0)
|
||||
|
||||
def test_invalid_month_raises(self):
|
||||
with self.assertRaises(ValueError):
|
||||
build_trade_stats_calendar([], 2026, 13, "all", _matches_all)
|
||||
|
||||
def test_initial_calendar_uses_current_month(self):
|
||||
pnls = [(2.5, None, "2026-06-20", _row())]
|
||||
payload = build_initial_stats_calendar(
|
||||
pnls,
|
||||
datetime(2026, 6, 26, 12, 0),
|
||||
_matches_all,
|
||||
reset_hour=8,
|
||||
)
|
||||
self.assertEqual(payload["year"], 2026)
|
||||
self.assertEqual(payload["month"], 6)
|
||||
self.assertEqual(payload["month_open_count"], 1)
|
||||
self.assertIn("2026-06-20", payload["days"])
|
||||
|
||||
def test_bootstrap_json_roundtrip(self):
|
||||
pnls = [(2.5, None, "2026-06-20", _row())]
|
||||
payload, raw = build_stats_calendar_bootstrap(
|
||||
pnls,
|
||||
datetime(2026, 6, 26, 12, 0),
|
||||
_matches_all,
|
||||
reset_hour=8,
|
||||
)
|
||||
self.assertIsNotNone(payload)
|
||||
self.assertIsNotNone(raw)
|
||||
self.assertIn('"month_open_count":1', raw.replace(" ", ""))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,101 +1,101 @@
|
||||
"""趋势回调运行中计划:实际成交价重算补仓表与金额盈亏比。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from strategy_snapshot_lib import attach_trend_dca_levels # noqa: E402
|
||||
from strategy_trend_lib import ( # noqa: E402
|
||||
calc_trend_plan_money_metrics,
|
||||
trend_leg_display_price,
|
||||
)
|
||||
|
||||
|
||||
class TestTrendDcaEnrichFills(unittest.TestCase):
|
||||
def _base_plan(self, **overrides):
|
||||
plan = {
|
||||
"direction": "long",
|
||||
"stop_loss": 0.329,
|
||||
"take_profit": 0.476,
|
||||
"first_order_amount": 115,
|
||||
"snapshot_available_usdt": 97.98,
|
||||
"risk_percent": 5,
|
||||
"contract_size": 1.0,
|
||||
"grid_prices_json": json.dumps([0.3465, 0.343, 0.3395, 0.336, 0.3325]),
|
||||
"leg_amounts_json": json.dumps([23, 23, 23, 23, 23]),
|
||||
"dca_legs": 5,
|
||||
"first_order_done": 1,
|
||||
"legs_done": 0,
|
||||
"avg_entry_price": 0.3537,
|
||||
"order_amount_open": 115,
|
||||
"target_order_amount": 230,
|
||||
"leg_fill_prices_json": json.dumps([0.3537]),
|
||||
}
|
||||
plan.update(overrides)
|
||||
return plan
|
||||
|
||||
def test_header_money_rr_not_price_rr(self):
|
||||
plan = self._base_plan()
|
||||
metrics = calc_trend_plan_money_metrics(plan)
|
||||
self.assertAlmostEqual(metrics["risk_amount_u"], 4.899, places=2)
|
||||
self.assertIsNotNone(metrics["money_rr"])
|
||||
self.assertLess(metrics["money_rr"], 4.0)
|
||||
|
||||
def test_done_dca_uses_actual_fill_price(self):
|
||||
plan = self._base_plan(
|
||||
legs_done=1,
|
||||
avg_entry_price=0.3512,
|
||||
order_amount_open=138,
|
||||
leg_fill_prices_json=json.dumps([0.3537, 0.3458]),
|
||||
)
|
||||
enriched = attach_trend_dca_levels(plan)
|
||||
levels = enriched["dca_levels"]
|
||||
self.assertEqual(len(levels), 6)
|
||||
dca1 = levels[1]
|
||||
self.assertEqual(dca1["status"], "done")
|
||||
self.assertAlmostEqual(dca1["price"], 0.3458, places=4)
|
||||
self.assertIsNotNone(dca1["avg_entry"])
|
||||
self.assertIsNotNone(dca1["rr"])
|
||||
dca2 = levels[2]
|
||||
self.assertEqual(dca2["status"], "pending")
|
||||
self.assertAlmostEqual(dca2["price"], 0.343, places=4)
|
||||
|
||||
def test_missing_dca_fills_use_grid_trigger_not_inferred_price(self):
|
||||
"""缺补仓成交价时:触发价用计划网格,末档均价对齐头部,禁止反推离谱成交价。"""
|
||||
plan = self._base_plan(
|
||||
legs_done=2,
|
||||
avg_entry_price=0.3507,
|
||||
order_amount_open=161,
|
||||
leg_fill_prices_json=json.dumps([0.3436]),
|
||||
grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]),
|
||||
)
|
||||
enriched = attach_trend_dca_levels(plan)
|
||||
levels = enriched["dca_levels"]
|
||||
dca1 = levels[1]
|
||||
dca2 = levels[2]
|
||||
self.assertEqual(dca1["status"], "done")
|
||||
self.assertAlmostEqual(dca1["price"], 0.343, places=4)
|
||||
self.assertEqual(dca2["status"], "done")
|
||||
self.assertAlmostEqual(dca2["price"], 0.343, places=4)
|
||||
self.assertAlmostEqual(dca2["avg_entry"], 0.3507, places=4)
|
||||
self.assertLess(dca2["price"], 0.36)
|
||||
|
||||
def test_display_price_never_infers_from_target_avg(self):
|
||||
"""四所共用:缺记录时只用网格,不因均价反推离谱触发价。"""
|
||||
plan = self._base_plan(
|
||||
legs_done=2,
|
||||
avg_entry_price=0.3507,
|
||||
leg_fill_prices_json=json.dumps([0.3436]),
|
||||
grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]),
|
||||
)
|
||||
self.assertAlmostEqual(trend_leg_display_price(plan, 2), 0.343, places=4)
|
||||
self.assertLess(trend_leg_display_price(plan, 2), 0.36)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""趋势回调运行中计划:实际成交价重算补仓表与金额盈亏比。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from lib.strategy.strategy_snapshot_lib import attach_trend_dca_levels # noqa: E402
|
||||
from lib.strategy.strategy_trend_lib import ( # noqa: E402
|
||||
calc_trend_plan_money_metrics,
|
||||
trend_leg_display_price,
|
||||
)
|
||||
|
||||
|
||||
class TestTrendDcaEnrichFills(unittest.TestCase):
|
||||
def _base_plan(self, **overrides):
|
||||
plan = {
|
||||
"direction": "long",
|
||||
"stop_loss": 0.329,
|
||||
"take_profit": 0.476,
|
||||
"first_order_amount": 115,
|
||||
"snapshot_available_usdt": 97.98,
|
||||
"risk_percent": 5,
|
||||
"contract_size": 1.0,
|
||||
"grid_prices_json": json.dumps([0.3465, 0.343, 0.3395, 0.336, 0.3325]),
|
||||
"leg_amounts_json": json.dumps([23, 23, 23, 23, 23]),
|
||||
"dca_legs": 5,
|
||||
"first_order_done": 1,
|
||||
"legs_done": 0,
|
||||
"avg_entry_price": 0.3537,
|
||||
"order_amount_open": 115,
|
||||
"target_order_amount": 230,
|
||||
"leg_fill_prices_json": json.dumps([0.3537]),
|
||||
}
|
||||
plan.update(overrides)
|
||||
return plan
|
||||
|
||||
def test_header_money_rr_not_price_rr(self):
|
||||
plan = self._base_plan()
|
||||
metrics = calc_trend_plan_money_metrics(plan)
|
||||
self.assertAlmostEqual(metrics["risk_amount_u"], 4.899, places=2)
|
||||
self.assertIsNotNone(metrics["money_rr"])
|
||||
self.assertLess(metrics["money_rr"], 4.0)
|
||||
|
||||
def test_done_dca_uses_actual_fill_price(self):
|
||||
plan = self._base_plan(
|
||||
legs_done=1,
|
||||
avg_entry_price=0.3512,
|
||||
order_amount_open=138,
|
||||
leg_fill_prices_json=json.dumps([0.3537, 0.3458]),
|
||||
)
|
||||
enriched = attach_trend_dca_levels(plan)
|
||||
levels = enriched["dca_levels"]
|
||||
self.assertEqual(len(levels), 6)
|
||||
dca1 = levels[1]
|
||||
self.assertEqual(dca1["status"], "done")
|
||||
self.assertAlmostEqual(dca1["price"], 0.3458, places=4)
|
||||
self.assertIsNotNone(dca1["avg_entry"])
|
||||
self.assertIsNotNone(dca1["rr"])
|
||||
dca2 = levels[2]
|
||||
self.assertEqual(dca2["status"], "pending")
|
||||
self.assertAlmostEqual(dca2["price"], 0.343, places=4)
|
||||
|
||||
def test_missing_dca_fills_use_grid_trigger_not_inferred_price(self):
|
||||
"""缺补仓成交价时:触发价用计划网格,末档均价对齐头部,禁止反推离谱成交价。"""
|
||||
plan = self._base_plan(
|
||||
legs_done=2,
|
||||
avg_entry_price=0.3507,
|
||||
order_amount_open=161,
|
||||
leg_fill_prices_json=json.dumps([0.3436]),
|
||||
grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]),
|
||||
)
|
||||
enriched = attach_trend_dca_levels(plan)
|
||||
levels = enriched["dca_levels"]
|
||||
dca1 = levels[1]
|
||||
dca2 = levels[2]
|
||||
self.assertEqual(dca1["status"], "done")
|
||||
self.assertAlmostEqual(dca1["price"], 0.343, places=4)
|
||||
self.assertEqual(dca2["status"], "done")
|
||||
self.assertAlmostEqual(dca2["price"], 0.343, places=4)
|
||||
self.assertAlmostEqual(dca2["avg_entry"], 0.3507, places=4)
|
||||
self.assertLess(dca2["price"], 0.36)
|
||||
|
||||
def test_display_price_never_infers_from_target_avg(self):
|
||||
"""四所共用:缺记录时只用网格,不因均价反推离谱触发价。"""
|
||||
plan = self._base_plan(
|
||||
legs_done=2,
|
||||
avg_entry_price=0.3507,
|
||||
leg_fill_prices_json=json.dumps([0.3436]),
|
||||
grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]),
|
||||
)
|
||||
self.assertAlmostEqual(trend_leg_display_price(plan, 2), 0.343, places=4)
|
||||
self.assertLess(trend_leg_display_price(plan, 2), 0.36)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
+43
-43
@@ -1,43 +1,43 @@
|
||||
"""趋势回调:补仓触达与有效保证金估算。"""
|
||||
from strategy_trend_lib import trend_dca_level_reached, trend_effective_margin_capital
|
||||
|
||||
|
||||
def test_trend_dca_short_monotonic_up_fills_missed_legs():
|
||||
"""做空价升:旧逻辑需 last<level,价越过 0.3437 后 last 已高于该档则永不补仓。"""
|
||||
direction = "short"
|
||||
levels = [0.3413, 0.3437, 0.346, 0.3483, 0.3507]
|
||||
pf = 0.353
|
||||
filled = [lv for lv in levels if trend_dca_level_reached(direction, pf, lv)]
|
||||
assert filled == levels
|
||||
|
||||
|
||||
def test_trend_dca_short_not_before_first_level():
|
||||
direction = "short"
|
||||
assert not trend_dca_level_reached(direction, 0.336, 0.3413)
|
||||
assert trend_dca_level_reached(direction, 0.3413, 0.3413)
|
||||
|
||||
|
||||
def test_trend_dca_long_mark_below_trigger():
|
||||
direction = "long"
|
||||
assert trend_dca_level_reached(direction, 0.344, 0.3465)
|
||||
assert not trend_dca_level_reached(direction, 0.347, 0.3465)
|
||||
|
||||
|
||||
def test_trend_effective_margin_first_leg_only():
|
||||
plan = {
|
||||
"plan_margin_capital": 12.11,
|
||||
"target_order_amount": 359.0,
|
||||
"order_amount_open": 179.0,
|
||||
"first_order_amount": 179.0,
|
||||
}
|
||||
m = trend_effective_margin_capital(plan)
|
||||
assert abs(m - 12.11 * 179 / 359) < 0.01
|
||||
|
||||
|
||||
def test_trend_effective_margin_full_position():
|
||||
plan = {
|
||||
"plan_margin_capital": 12.11,
|
||||
"target_order_amount": 359.0,
|
||||
"order_amount_open": 359.0,
|
||||
}
|
||||
assert trend_effective_margin_capital(plan) == 12.11
|
||||
"""趋势回调:补仓触达与有效保证金估算。"""
|
||||
from lib.strategy.strategy_trend_lib import trend_dca_level_reached, trend_effective_margin_capital
|
||||
|
||||
|
||||
def test_trend_dca_short_monotonic_up_fills_missed_legs():
|
||||
"""做空价升:旧逻辑需 last<level,价越过 0.3437 后 last 已高于该档则永不补仓。"""
|
||||
direction = "short"
|
||||
levels = [0.3413, 0.3437, 0.346, 0.3483, 0.3507]
|
||||
pf = 0.353
|
||||
filled = [lv for lv in levels if trend_dca_level_reached(direction, pf, lv)]
|
||||
assert filled == levels
|
||||
|
||||
|
||||
def test_trend_dca_short_not_before_first_level():
|
||||
direction = "short"
|
||||
assert not trend_dca_level_reached(direction, 0.336, 0.3413)
|
||||
assert trend_dca_level_reached(direction, 0.3413, 0.3413)
|
||||
|
||||
|
||||
def test_trend_dca_long_mark_below_trigger():
|
||||
direction = "long"
|
||||
assert trend_dca_level_reached(direction, 0.344, 0.3465)
|
||||
assert not trend_dca_level_reached(direction, 0.347, 0.3465)
|
||||
|
||||
|
||||
def test_trend_effective_margin_first_leg_only():
|
||||
plan = {
|
||||
"plan_margin_capital": 12.11,
|
||||
"target_order_amount": 359.0,
|
||||
"order_amount_open": 179.0,
|
||||
"first_order_amount": 179.0,
|
||||
}
|
||||
m = trend_effective_margin_capital(plan)
|
||||
assert abs(m - 12.11 * 179 / 359) < 0.01
|
||||
|
||||
|
||||
def test_trend_effective_margin_full_position():
|
||||
plan = {
|
||||
"plan_margin_capital": 12.11,
|
||||
"target_order_amount": 359.0,
|
||||
"order_amount_open": 359.0,
|
||||
}
|
||||
assert trend_effective_margin_capital(plan) == 12.11
|
||||
|
||||
@@ -1,92 +1,92 @@
|
||||
"""趋势计划结束:须写入 trade_records(四所统一)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import sqlite3
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from strategy_trend_register import _call_insert_trade_record # noqa: E402
|
||||
|
||||
|
||||
class _GateBotLikeModule:
|
||||
"""模拟 gate_bot:曾有 trend_plan_id 但缺 entry_reason 参数。"""
|
||||
|
||||
@staticmethod
|
||||
def insert_trade_record(
|
||||
conn,
|
||||
symbol,
|
||||
monitor_type,
|
||||
direction,
|
||||
trigger_price,
|
||||
stop_loss,
|
||||
initial_stop_loss=None,
|
||||
take_profit=None,
|
||||
margin_capital=None,
|
||||
leverage=None,
|
||||
pnl_amount=0,
|
||||
hold_seconds=0,
|
||||
trade_style=None,
|
||||
risk_amount=None,
|
||||
planned_rr=None,
|
||||
actual_rr=None,
|
||||
result="",
|
||||
miss_reason=None,
|
||||
opened_at=None,
|
||||
opened_at_ms=None,
|
||||
closed_at=None,
|
||||
closed_at_ms=None,
|
||||
exchange_trade_id=None,
|
||||
trend_plan_id=None,
|
||||
):
|
||||
conn.execute(
|
||||
"INSERT INTO trade_records (symbol, monitor_type, direction, result, trend_plan_id) "
|
||||
"VALUES (?,?,?,?,?)",
|
||||
(symbol, monitor_type, direction, result, trend_plan_id),
|
||||
)
|
||||
|
||||
|
||||
class TestTrendFinalizeTradeRecord(unittest.TestCase):
|
||||
def test_call_insert_filters_unknown_entry_reason(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.execute(
|
||||
"CREATE TABLE trade_records (symbol TEXT, monitor_type TEXT, direction TEXT, "
|
||||
"result TEXT, trend_plan_id INTEGER)"
|
||||
)
|
||||
m = _GateBotLikeModule()
|
||||
_call_insert_trade_record(
|
||||
m,
|
||||
4,
|
||||
dict(
|
||||
conn=conn,
|
||||
symbol="ONDO/USDT",
|
||||
monitor_type="趋势回调",
|
||||
direction="long",
|
||||
trigger_price=0.35,
|
||||
stop_loss=0.329,
|
||||
result="止损",
|
||||
entry_reason="趋势回调",
|
||||
),
|
||||
)
|
||||
row = conn.execute(
|
||||
"SELECT symbol, monitor_type, trend_plan_id FROM trade_records"
|
||||
).fetchone()
|
||||
self.assertEqual(row[0], "ONDO/USDT")
|
||||
self.assertEqual(row[1], "趋势回调")
|
||||
self.assertEqual(row[2], 4)
|
||||
|
||||
def test_gate_bot_insert_accepts_entry_reason(self):
|
||||
from crypto_monitor_gate_bot import app as gate_bot_app # noqa: E402
|
||||
|
||||
sig = inspect.signature(gate_bot_app.insert_trade_record)
|
||||
self.assertIn("entry_reason", sig.parameters)
|
||||
self.assertIn("trend_plan_id", sig.parameters)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""趋势计划结束:须写入 trade_records(四所统一)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import sqlite3
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from lib.strategy.strategy_trend_register import _call_insert_trade_record # noqa: E402
|
||||
|
||||
|
||||
class _GateBotLikeModule:
|
||||
"""模拟 gate_bot:曾有 trend_plan_id 但缺 entry_reason 参数。"""
|
||||
|
||||
@staticmethod
|
||||
def insert_trade_record(
|
||||
conn,
|
||||
symbol,
|
||||
monitor_type,
|
||||
direction,
|
||||
trigger_price,
|
||||
stop_loss,
|
||||
initial_stop_loss=None,
|
||||
take_profit=None,
|
||||
margin_capital=None,
|
||||
leverage=None,
|
||||
pnl_amount=0,
|
||||
hold_seconds=0,
|
||||
trade_style=None,
|
||||
risk_amount=None,
|
||||
planned_rr=None,
|
||||
actual_rr=None,
|
||||
result="",
|
||||
miss_reason=None,
|
||||
opened_at=None,
|
||||
opened_at_ms=None,
|
||||
closed_at=None,
|
||||
closed_at_ms=None,
|
||||
exchange_trade_id=None,
|
||||
trend_plan_id=None,
|
||||
):
|
||||
conn.execute(
|
||||
"INSERT INTO trade_records (symbol, monitor_type, direction, result, trend_plan_id) "
|
||||
"VALUES (?,?,?,?,?)",
|
||||
(symbol, monitor_type, direction, result, trend_plan_id),
|
||||
)
|
||||
|
||||
|
||||
class TestTrendFinalizeTradeRecord(unittest.TestCase):
|
||||
def test_call_insert_filters_unknown_entry_reason(self):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.execute(
|
||||
"CREATE TABLE trade_records (symbol TEXT, monitor_type TEXT, direction TEXT, "
|
||||
"result TEXT, trend_plan_id INTEGER)"
|
||||
)
|
||||
m = _GateBotLikeModule()
|
||||
_call_insert_trade_record(
|
||||
m,
|
||||
4,
|
||||
dict(
|
||||
conn=conn,
|
||||
symbol="ONDO/USDT",
|
||||
monitor_type="趋势回调",
|
||||
direction="long",
|
||||
trigger_price=0.35,
|
||||
stop_loss=0.329,
|
||||
result="止损",
|
||||
entry_reason="趋势回调",
|
||||
),
|
||||
)
|
||||
row = conn.execute(
|
||||
"SELECT symbol, monitor_type, trend_plan_id FROM trade_records"
|
||||
).fetchone()
|
||||
self.assertEqual(row[0], "ONDO/USDT")
|
||||
self.assertEqual(row[1], "趋势回调")
|
||||
self.assertEqual(row[2], 4)
|
||||
|
||||
def test_gate_bot_insert_accepts_entry_reason(self):
|
||||
from crypto_monitor_gate_bot import app as gate_bot_app # noqa: E402
|
||||
|
||||
sig = inspect.signature(gate_bot_app.insert_trade_record)
|
||||
self.assertIn("entry_reason", sig.parameters)
|
||||
self.assertIn("trend_plan_id", sig.parameters)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,44 +1,44 @@
|
||||
"""趋势回调中控 enrich:补仓次数与加仓价。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from strategy_trend_register import _trend_add_leg_fields # noqa: E402
|
||||
|
||||
|
||||
class TestTrendHubEnrich(unittest.TestCase):
|
||||
def test_add_count_and_prices(self):
|
||||
mock_ex = MagicMock()
|
||||
mock_ex.price_to_precision = lambda sym, px: f"{float(px):.4f}"
|
||||
app_mod = MagicMock()
|
||||
app_mod.exchange = mock_ex
|
||||
app_mod.ensure_markets_loaded = MagicMock()
|
||||
app_mod.normalize_exchange_symbol = lambda s: s
|
||||
cfg = {"app_module": app_mod}
|
||||
raw = {
|
||||
"symbol": "ETH/USDT",
|
||||
"exchange_symbol": "ETH/USDT:USDT",
|
||||
"legs_done": 2,
|
||||
"dca_legs": 5,
|
||||
"grid_prices_json": json.dumps([1800.1, 1750.2, 1700.3]),
|
||||
"stop_loss": 1600,
|
||||
"take_profit": 2000,
|
||||
"avg_entry_price": 1820.5,
|
||||
}
|
||||
out = _trend_add_leg_fields(cfg, raw)
|
||||
self.assertEqual(out["add_count"], 2)
|
||||
self.assertEqual(out["add_count_total"], 5)
|
||||
self.assertEqual(out["add_prices"], [1800.1, 1750.2])
|
||||
self.assertEqual(len(out["add_prices_display"]), 2)
|
||||
self.assertEqual(out["stop_loss_display"], "1600.0000")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""趋势回调中控 enrich:补仓次数与加仓价。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from lib.strategy.strategy_trend_register import _trend_add_leg_fields # noqa: E402
|
||||
|
||||
|
||||
class TestTrendHubEnrich(unittest.TestCase):
|
||||
def test_add_count_and_prices(self):
|
||||
mock_ex = MagicMock()
|
||||
mock_ex.price_to_precision = lambda sym, px: f"{float(px):.4f}"
|
||||
app_mod = MagicMock()
|
||||
app_mod.exchange = mock_ex
|
||||
app_mod.ensure_markets_loaded = MagicMock()
|
||||
app_mod.normalize_exchange_symbol = lambda s: s
|
||||
cfg = {"app_module": app_mod}
|
||||
raw = {
|
||||
"symbol": "ETH/USDT",
|
||||
"exchange_symbol": "ETH/USDT:USDT",
|
||||
"legs_done": 2,
|
||||
"dca_legs": 5,
|
||||
"grid_prices_json": json.dumps([1800.1, 1750.2, 1700.3]),
|
||||
"stop_loss": 1600,
|
||||
"take_profit": 2000,
|
||||
"avg_entry_price": 1820.5,
|
||||
}
|
||||
out = _trend_add_leg_fields(cfg, raw)
|
||||
self.assertEqual(out["add_count"], 2)
|
||||
self.assertEqual(out["add_count_total"], 5)
|
||||
self.assertEqual(out["add_prices"], [1800.1, 1750.2])
|
||||
self.assertEqual(len(out["add_prices_display"]), 2)
|
||||
self.assertEqual(out["stop_loss_display"], "1600.0000")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,92 +1,92 @@
|
||||
"""四所趋势 enrich:实例与中控 monitor 字段一致。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from strategy_trend_register import ( # noqa: E402
|
||||
enrich_trend_plan,
|
||||
enrich_trend_plan_for_hub,
|
||||
)
|
||||
|
||||
|
||||
class _FakeModule:
|
||||
@staticmethod
|
||||
def normalize_exchange_symbol(sym):
|
||||
return sym
|
||||
|
||||
@staticmethod
|
||||
def ensure_markets_loaded():
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_live_position_exchange_metrics(ex_sym, direction, order_leverage=None):
|
||||
return {
|
||||
"entry_price": 0.3507,
|
||||
"mark_price": 0.3431,
|
||||
"unrealized_pnl": -1.23,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_contract_size(_ex_sym):
|
||||
return 1.0
|
||||
|
||||
|
||||
class TestTrendHubEnrichUnified(unittest.TestCase):
|
||||
def _cfg(self):
|
||||
return {
|
||||
"app_module": _FakeModule(),
|
||||
"breakeven_offset_pct": 0.3,
|
||||
"row_to_dict": lambda row: dict(row) if not isinstance(row, dict) else row,
|
||||
}
|
||||
|
||||
def _plan_row(self):
|
||||
return {
|
||||
"id": 4,
|
||||
"symbol": "ONDO/USDT:USDT",
|
||||
"exchange_symbol": "ONDO/USDT:USDT",
|
||||
"direction": "long",
|
||||
"stop_loss": 0.329,
|
||||
"take_profit": 0.476,
|
||||
"add_upper": 0.35,
|
||||
"first_order_amount": 115,
|
||||
"snapshot_available_usdt": 97.98,
|
||||
"risk_percent": 5,
|
||||
"contract_size": 1.0,
|
||||
"grid_prices_json": json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]),
|
||||
"leg_amounts_json": json.dumps([23, 23, 23, 23, 23]),
|
||||
"dca_legs": 5,
|
||||
"first_order_done": 1,
|
||||
"legs_done": 2,
|
||||
"avg_entry_price": 0.3434,
|
||||
"order_amount_open": 161,
|
||||
"leg_fill_prices_json": json.dumps([0.3436]),
|
||||
"leverage": 10,
|
||||
"plan_margin_capital": 8.17,
|
||||
}
|
||||
|
||||
def test_hub_and_page_share_live_avg_and_dca_levels(self):
|
||||
cfg = self._cfg()
|
||||
row = self._plan_row()
|
||||
page = enrich_trend_plan(cfg, row)
|
||||
hub = enrich_trend_plan_for_hub(cfg, row)
|
||||
self.assertAlmostEqual(page["avg_entry_price"], 0.3507, places=4)
|
||||
self.assertAlmostEqual(hub["avg_entry_price"], 0.3507, places=4)
|
||||
self.assertIn("dca_levels", page)
|
||||
self.assertIn("dca_levels", hub)
|
||||
last_done = hub["dca_levels"][2]
|
||||
self.assertEqual(last_done["status"], "done")
|
||||
self.assertAlmostEqual(last_done["price"], 0.343, places=4)
|
||||
self.assertAlmostEqual(last_done["avg_entry"], 0.3507, places=4)
|
||||
self.assertLess(last_done["price"], 0.36)
|
||||
self.assertEqual(hub.get("monitor_source"), "趋势回调计划")
|
||||
self.assertEqual(hub.get("add_count"), 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""四所趋势 enrich:实例与中控 monitor 字段一致。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from lib.strategy.strategy_trend_register import ( # noqa: E402
|
||||
enrich_trend_plan,
|
||||
enrich_trend_plan_for_hub,
|
||||
)
|
||||
|
||||
|
||||
class _FakeModule:
|
||||
@staticmethod
|
||||
def normalize_exchange_symbol(sym):
|
||||
return sym
|
||||
|
||||
@staticmethod
|
||||
def ensure_markets_loaded():
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_live_position_exchange_metrics(ex_sym, direction, order_leverage=None):
|
||||
return {
|
||||
"entry_price": 0.3507,
|
||||
"mark_price": 0.3431,
|
||||
"unrealized_pnl": -1.23,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_contract_size(_ex_sym):
|
||||
return 1.0
|
||||
|
||||
|
||||
class TestTrendHubEnrichUnified(unittest.TestCase):
|
||||
def _cfg(self):
|
||||
return {
|
||||
"app_module": _FakeModule(),
|
||||
"breakeven_offset_pct": 0.3,
|
||||
"row_to_dict": lambda row: dict(row) if not isinstance(row, dict) else row,
|
||||
}
|
||||
|
||||
def _plan_row(self):
|
||||
return {
|
||||
"id": 4,
|
||||
"symbol": "ONDO/USDT:USDT",
|
||||
"exchange_symbol": "ONDO/USDT:USDT",
|
||||
"direction": "long",
|
||||
"stop_loss": 0.329,
|
||||
"take_profit": 0.476,
|
||||
"add_upper": 0.35,
|
||||
"first_order_amount": 115,
|
||||
"snapshot_available_usdt": 97.98,
|
||||
"risk_percent": 5,
|
||||
"contract_size": 1.0,
|
||||
"grid_prices_json": json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]),
|
||||
"leg_amounts_json": json.dumps([23, 23, 23, 23, 23]),
|
||||
"dca_legs": 5,
|
||||
"first_order_done": 1,
|
||||
"legs_done": 2,
|
||||
"avg_entry_price": 0.3434,
|
||||
"order_amount_open": 161,
|
||||
"leg_fill_prices_json": json.dumps([0.3436]),
|
||||
"leverage": 10,
|
||||
"plan_margin_capital": 8.17,
|
||||
}
|
||||
|
||||
def test_hub_and_page_share_live_avg_and_dca_levels(self):
|
||||
cfg = self._cfg()
|
||||
row = self._plan_row()
|
||||
page = enrich_trend_plan(cfg, row)
|
||||
hub = enrich_trend_plan_for_hub(cfg, row)
|
||||
self.assertAlmostEqual(page["avg_entry_price"], 0.3507, places=4)
|
||||
self.assertAlmostEqual(hub["avg_entry_price"], 0.3507, places=4)
|
||||
self.assertIn("dca_levels", page)
|
||||
self.assertIn("dca_levels", hub)
|
||||
last_done = hub["dca_levels"][2]
|
||||
self.assertEqual(last_done["status"], "done")
|
||||
self.assertAlmostEqual(last_done["price"], 0.343, places=4)
|
||||
self.assertAlmostEqual(last_done["avg_entry"], 0.3507, places=4)
|
||||
self.assertLess(last_done["price"], 0.36)
|
||||
self.assertEqual(hub.get("monitor_source"), "趋势回调计划")
|
||||
self.assertEqual(hub.get("add_count"), 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,41 +1,41 @@
|
||||
"""趋势补仓下单:空 params 不得变成 None(ccxt 会报 not iterable)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from strategy_trend_exchange import trend_market_add
|
||||
|
||||
|
||||
class TestTrendMarketAddParams(unittest.TestCase):
|
||||
def test_empty_gate_params_not_passed_as_none(self):
|
||||
ex = MagicMock()
|
||||
ex.create_order.return_value = {"id": "1", "average": 0.34}
|
||||
app = MagicMock()
|
||||
app.exchange = ex
|
||||
app.ensure_markets_loaded = MagicMock()
|
||||
app.build_gate_order_params = MagicMock(return_value={})
|
||||
cfg = {"app_module": app}
|
||||
|
||||
trend_market_add(cfg, "ONDO/USDT:USDT", "long", 23, 10)
|
||||
|
||||
args = ex.create_order.call_args
|
||||
self.assertEqual(args[0][5], {})
|
||||
self.assertIsNotNone(args[0][5])
|
||||
|
||||
def test_binance_oneway_empty_params_not_passed_as_none(self):
|
||||
ex = MagicMock()
|
||||
ex.create_order.return_value = {"id": "1"}
|
||||
app = MagicMock(spec=["exchange", "ensure_markets_loaded", "build_binance_order_params"])
|
||||
app.exchange = ex
|
||||
app.ensure_markets_loaded = MagicMock()
|
||||
app.build_binance_order_params = MagicMock(return_value={})
|
||||
cfg = {"app_module": app}
|
||||
|
||||
trend_market_add(cfg, "BTC/USDT:USDT", "long", 1, 10)
|
||||
|
||||
self.assertEqual(ex.create_order.call_args[0][5], {})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""趋势补仓下单:空 params 不得变成 None(ccxt 会报 not iterable)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from lib.strategy.strategy_trend_exchange import trend_market_add
|
||||
|
||||
|
||||
class TestTrendMarketAddParams(unittest.TestCase):
|
||||
def test_empty_gate_params_not_passed_as_none(self):
|
||||
ex = MagicMock()
|
||||
ex.create_order.return_value = {"id": "1", "average": 0.34}
|
||||
app = MagicMock()
|
||||
app.exchange = ex
|
||||
app.ensure_markets_loaded = MagicMock()
|
||||
app.build_gate_order_params = MagicMock(return_value={})
|
||||
cfg = {"app_module": app}
|
||||
|
||||
trend_market_add(cfg, "ONDO/USDT:USDT", "long", 23, 10)
|
||||
|
||||
args = ex.create_order.call_args
|
||||
self.assertEqual(args[0][5], {})
|
||||
self.assertIsNotNone(args[0][5])
|
||||
|
||||
def test_binance_oneway_empty_params_not_passed_as_none(self):
|
||||
ex = MagicMock()
|
||||
ex.create_order.return_value = {"id": "1"}
|
||||
app = MagicMock(spec=["exchange", "ensure_markets_loaded", "build_binance_order_params"])
|
||||
app.exchange = ex
|
||||
app.ensure_markets_loaded = MagicMock()
|
||||
app.build_binance_order_params = MagicMock(return_value={})
|
||||
cfg = {"app_module": app}
|
||||
|
||||
trend_market_add(cfg, "BTC/USDT:USDT", "long", 1, 10)
|
||||
|
||||
self.assertEqual(ex.create_order.call_args[0][5], {})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,58 +1,58 @@
|
||||
"""趋势回调预览:止盈盈利 U、止损金额 U、金额盈亏比。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from strategy_trend_lib import ( # noqa: E402
|
||||
build_trend_preview_level_rows,
|
||||
calc_money_reward_risk_ratio,
|
||||
calc_risk_budget_usdt,
|
||||
calc_tp_profit_usdt,
|
||||
)
|
||||
|
||||
|
||||
class TestTrendPreviewTp(unittest.TestCase):
|
||||
def test_risk_budget_from_snapshot(self):
|
||||
self.assertAlmostEqual(calc_risk_budget_usdt(110.73, 5), 5.5365, places=2)
|
||||
|
||||
def test_short_profit_at_form_take_profit(self):
|
||||
profit = calc_tp_profit_usdt("short", 72.53, 66.0, 1114, 0.00167)
|
||||
self.assertIsNotNone(profit)
|
||||
self.assertGreater(profit, 0)
|
||||
rr = calc_money_reward_risk_ratio(profit, 5.5365)
|
||||
self.assertIsNotNone(rr)
|
||||
self.assertGreater(rr, 1.5)
|
||||
|
||||
def test_preview_levels_use_money_rr(self):
|
||||
preview = {
|
||||
"direction": "short",
|
||||
"live_price_ref": 72.53,
|
||||
"stop_loss": 75.5,
|
||||
"take_profit": 66.0,
|
||||
"first_order_amount": 1114,
|
||||
"snapshot_available_usdt": 110.73,
|
||||
"risk_percent": 5,
|
||||
"contract_size": 0.00167,
|
||||
"grid_prices_json": json.dumps([73.42, 73.83]),
|
||||
"leg_amounts_json": json.dumps([222, 222]),
|
||||
}
|
||||
enriched, rows = build_trend_preview_level_rows(preview)
|
||||
self.assertAlmostEqual(enriched["preview_risk_amount_u"], 5.5365, places=2)
|
||||
self.assertEqual(enriched["preview_take_profit_price"], 66.0)
|
||||
self.assertEqual(len(rows), 3)
|
||||
self.assertEqual(rows[0]["label"], "首仓")
|
||||
self.assertEqual(rows[0]["risk_u"], enriched["preview_risk_amount_u"])
|
||||
self.assertIsNotNone(rows[0]["profit_u"])
|
||||
self.assertAlmostEqual(rows[0]["rr"], rows[0]["profit_u"] / 5.5365, places=2)
|
||||
self.assertEqual(rows[1]["risk_u"], enriched["preview_risk_amount_u"])
|
||||
self.assertGreater(rows[2]["profit_u"], rows[1]["profit_u"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"""趋势回调预览:止盈盈利 U、止损金额 U、金额盈亏比。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from lib.strategy.strategy_trend_lib import ( # noqa: E402
|
||||
build_trend_preview_level_rows,
|
||||
calc_money_reward_risk_ratio,
|
||||
calc_risk_budget_usdt,
|
||||
calc_tp_profit_usdt,
|
||||
)
|
||||
|
||||
|
||||
class TestTrendPreviewTp(unittest.TestCase):
|
||||
def test_risk_budget_from_snapshot(self):
|
||||
self.assertAlmostEqual(calc_risk_budget_usdt(110.73, 5), 5.5365, places=2)
|
||||
|
||||
def test_short_profit_at_form_take_profit(self):
|
||||
profit = calc_tp_profit_usdt("short", 72.53, 66.0, 1114, 0.00167)
|
||||
self.assertIsNotNone(profit)
|
||||
self.assertGreater(profit, 0)
|
||||
rr = calc_money_reward_risk_ratio(profit, 5.5365)
|
||||
self.assertIsNotNone(rr)
|
||||
self.assertGreater(rr, 1.5)
|
||||
|
||||
def test_preview_levels_use_money_rr(self):
|
||||
preview = {
|
||||
"direction": "short",
|
||||
"live_price_ref": 72.53,
|
||||
"stop_loss": 75.5,
|
||||
"take_profit": 66.0,
|
||||
"first_order_amount": 1114,
|
||||
"snapshot_available_usdt": 110.73,
|
||||
"risk_percent": 5,
|
||||
"contract_size": 0.00167,
|
||||
"grid_prices_json": json.dumps([73.42, 73.83]),
|
||||
"leg_amounts_json": json.dumps([222, 222]),
|
||||
}
|
||||
enriched, rows = build_trend_preview_level_rows(preview)
|
||||
self.assertAlmostEqual(enriched["preview_risk_amount_u"], 5.5365, places=2)
|
||||
self.assertEqual(enriched["preview_take_profit_price"], 66.0)
|
||||
self.assertEqual(len(rows), 3)
|
||||
self.assertEqual(rows[0]["label"], "首仓")
|
||||
self.assertEqual(rows[0]["risk_u"], enriched["preview_risk_amount_u"])
|
||||
self.assertIsNotNone(rows[0]["profit_u"])
|
||||
self.assertAlmostEqual(rows[0]["rr"], rows[0]["profit_u"] / 5.5365, places=2)
|
||||
self.assertEqual(rows[1]["risk_u"], enriched["preview_risk_amount_u"])
|
||||
self.assertGreater(rows[2]["profit_u"], rows[1]["profit_u"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,85 +1,85 @@
|
||||
"""触价开仓(回调/突破)关键位监控单元测试。"""
|
||||
from trigger_entry_key_monitor_lib import (
|
||||
BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE,
|
||||
CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE,
|
||||
LEGACY_TRIGGER_ENTRY_MONITOR_TYPE,
|
||||
TRIGGER_ENTRY_MONITOR_TYPES,
|
||||
TRIGGER_ENTRY_VALIDITY_HOURS,
|
||||
breakout_trigger_entry_crossed,
|
||||
check_trigger_entry_intent_limit,
|
||||
is_breakout_trigger_entry_key_monitor_type,
|
||||
is_trigger_entry_key_monitor_type,
|
||||
trigger_entry_invalidate,
|
||||
trigger_entry_reached,
|
||||
trigger_should_fire,
|
||||
validate_trigger_entry_geometry,
|
||||
)
|
||||
|
||||
|
||||
class _FakeConn:
|
||||
def execute(self, sql, params=()):
|
||||
class R:
|
||||
def fetchone(self_inner):
|
||||
return (2,)
|
||||
|
||||
return R()
|
||||
|
||||
|
||||
def test_trigger_entry_reached_long():
|
||||
assert trigger_entry_reached("long", 2049.0, 2050.0) is True
|
||||
assert trigger_entry_reached("long", 2051.0, 2050.0) is False
|
||||
|
||||
|
||||
def test_breakout_cross_long_up():
|
||||
assert breakout_trigger_entry_crossed("long", 99.0, 100.5, 100.0) is True
|
||||
assert breakout_trigger_entry_crossed("long", None, 101.0, 100.0) is True
|
||||
assert breakout_trigger_entry_crossed("long", 100.0, 100.0, 100.0) is False
|
||||
|
||||
|
||||
def test_breakout_cross_short_down():
|
||||
assert breakout_trigger_entry_crossed("short", 101.0, 99.5, 100.0) is True
|
||||
assert breakout_trigger_entry_crossed("short", None, 99.0, 100.0) is True
|
||||
|
||||
|
||||
def test_trigger_should_fire_modes():
|
||||
assert trigger_should_fire(CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, "long", 2049.0, 2050.0) is True
|
||||
assert trigger_should_fire(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, "long", 100.5, 100.0, 99.0) is True
|
||||
|
||||
|
||||
def test_validate_geometry_callback_long():
|
||||
assert validate_trigger_entry_geometry("long", 2050, 2000, 2100, 2090) is None
|
||||
|
||||
|
||||
def test_validate_geometry_breakout_short_requires_mark_above_entry():
|
||||
assert (
|
||||
validate_trigger_entry_geometry(
|
||||
"short", 551, 568, 540, 560, monitor_type=BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE
|
||||
)
|
||||
is None
|
||||
)
|
||||
err = validate_trigger_entry_geometry(
|
||||
"short", 551, 568, 540, 550, monitor_type=BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE
|
||||
)
|
||||
assert err is not None
|
||||
assert "高于入场价" in err
|
||||
|
||||
|
||||
def test_invalidate_breakout_sl_side():
|
||||
assert trigger_entry_invalidate(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, "long", 96, 97, 110) == "sl"
|
||||
assert trigger_entry_invalidate(CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, "long", 96, 97, 110) is None
|
||||
|
||||
|
||||
def test_intent_limit():
|
||||
ok, msg = check_trigger_entry_intent_limit(_FakeConn(), "2026-06-07", 2, 3)
|
||||
assert ok is False
|
||||
assert "意图" in msg
|
||||
|
||||
|
||||
def test_type_names():
|
||||
assert is_trigger_entry_key_monitor_type(CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE)
|
||||
assert is_trigger_entry_key_monitor_type(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE)
|
||||
assert is_trigger_entry_key_monitor_type(LEGACY_TRIGGER_ENTRY_MONITOR_TYPE)
|
||||
assert is_breakout_trigger_entry_key_monitor_type(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE)
|
||||
assert CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE in TRIGGER_ENTRY_MONITOR_TYPES
|
||||
assert BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE in TRIGGER_ENTRY_MONITOR_TYPES
|
||||
assert TRIGGER_ENTRY_VALIDITY_HOURS == 24
|
||||
"""触价开仓(回调/突破)关键位监控单元测试。"""
|
||||
from lib.key_monitor.trigger_entry_key_monitor_lib import (
|
||||
BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE,
|
||||
CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE,
|
||||
LEGACY_TRIGGER_ENTRY_MONITOR_TYPE,
|
||||
TRIGGER_ENTRY_MONITOR_TYPES,
|
||||
TRIGGER_ENTRY_VALIDITY_HOURS,
|
||||
breakout_trigger_entry_crossed,
|
||||
check_trigger_entry_intent_limit,
|
||||
is_breakout_trigger_entry_key_monitor_type,
|
||||
is_trigger_entry_key_monitor_type,
|
||||
trigger_entry_invalidate,
|
||||
trigger_entry_reached,
|
||||
trigger_should_fire,
|
||||
validate_trigger_entry_geometry,
|
||||
)
|
||||
|
||||
|
||||
class _FakeConn:
|
||||
def execute(self, sql, params=()):
|
||||
class R:
|
||||
def fetchone(self_inner):
|
||||
return (2,)
|
||||
|
||||
return R()
|
||||
|
||||
|
||||
def test_trigger_entry_reached_long():
|
||||
assert trigger_entry_reached("long", 2049.0, 2050.0) is True
|
||||
assert trigger_entry_reached("long", 2051.0, 2050.0) is False
|
||||
|
||||
|
||||
def test_breakout_cross_long_up():
|
||||
assert breakout_trigger_entry_crossed("long", 99.0, 100.5, 100.0) is True
|
||||
assert breakout_trigger_entry_crossed("long", None, 101.0, 100.0) is True
|
||||
assert breakout_trigger_entry_crossed("long", 100.0, 100.0, 100.0) is False
|
||||
|
||||
|
||||
def test_breakout_cross_short_down():
|
||||
assert breakout_trigger_entry_crossed("short", 101.0, 99.5, 100.0) is True
|
||||
assert breakout_trigger_entry_crossed("short", None, 99.0, 100.0) is True
|
||||
|
||||
|
||||
def test_trigger_should_fire_modes():
|
||||
assert trigger_should_fire(CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, "long", 2049.0, 2050.0) is True
|
||||
assert trigger_should_fire(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, "long", 100.5, 100.0, 99.0) is True
|
||||
|
||||
|
||||
def test_validate_geometry_callback_long():
|
||||
assert validate_trigger_entry_geometry("long", 2050, 2000, 2100, 2090) is None
|
||||
|
||||
|
||||
def test_validate_geometry_breakout_short_requires_mark_above_entry():
|
||||
assert (
|
||||
validate_trigger_entry_geometry(
|
||||
"short", 551, 568, 540, 560, monitor_type=BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE
|
||||
)
|
||||
is None
|
||||
)
|
||||
err = validate_trigger_entry_geometry(
|
||||
"short", 551, 568, 540, 550, monitor_type=BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE
|
||||
)
|
||||
assert err is not None
|
||||
assert "高于入场价" in err
|
||||
|
||||
|
||||
def test_invalidate_breakout_sl_side():
|
||||
assert trigger_entry_invalidate(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE, "long", 96, 97, 110) == "sl"
|
||||
assert trigger_entry_invalidate(CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE, "long", 96, 97, 110) is None
|
||||
|
||||
|
||||
def test_intent_limit():
|
||||
ok, msg = check_trigger_entry_intent_limit(_FakeConn(), "2026-06-07", 2, 3)
|
||||
assert ok is False
|
||||
assert "意图" in msg
|
||||
|
||||
|
||||
def test_type_names():
|
||||
assert is_trigger_entry_key_monitor_type(CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE)
|
||||
assert is_trigger_entry_key_monitor_type(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE)
|
||||
assert is_trigger_entry_key_monitor_type(LEGACY_TRIGGER_ENTRY_MONITOR_TYPE)
|
||||
assert is_breakout_trigger_entry_key_monitor_type(BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE)
|
||||
assert CALLBACK_TRIGGER_ENTRY_MONITOR_TYPE in TRIGGER_ENTRY_MONITOR_TYPES
|
||||
assert BREAKOUT_TRIGGER_ENTRY_MONITOR_TYPE in TRIGGER_ENTRY_MONITOR_TYPES
|
||||
assert TRIGGER_ENTRY_VALIDITY_HOURS == 24
|
||||
|
||||
Reference in New Issue
Block a user