"""Unit tests for the daily open-count limit helpers in ``daily_open_limit_lib``."""

import unittest

from daily_open_limit_lib import (
    build_daily_open_alert_prompt,
    can_trade_new_open,
    check_daily_open_hard_limit,
    count_opens_for_trading_day,
    daily_open_hard_limit_blocks,
    format_daily_open_counter_line,
    hard_limit_block_reason,
    load_daily_open_limits_from_env,
    parse_daily_open_hard_limit,
    should_send_daily_open_alert,
)


class _FakeConn:
    """Minimal connection stand-in whose every query yields a single count.

    ``execute`` ignores the SQL text and returns ``self`` so it can be chained
    with ``fetchone``, which always returns a one-column row holding the
    count given at construction. The most recent SQL and parameters are kept
    on the instance purely to make failing tests easier to debug.
    """

    def __init__(self, count: int):
        self._count = count
        self.last_sql = None
        self.last_params = None

    def execute(self, _sql, _params=()):
        # Parameters default to () like sqlite3's Connection.execute, so code
        # that issues a query without bind parameters also works with the fake.
        self.last_sql = _sql
        self.last_params = _params
        return self

    def fetchone(self):
        # One-column row, shaped like the result of an aggregate (count) query.
        return (self._count,)


class DailyOpenLimitLibTests(unittest.TestCase):
    """Behavioral tests for daily open counting, alerting and hard limiting."""

    def test_parse_hard_limit_zero_disables(self):
        """A hard limit of 0 (explicit or via default) parses to 0."""
        self.assertEqual(parse_daily_open_hard_limit("0"), 0)
        self.assertEqual(parse_daily_open_hard_limit(None, default=0), 0)

    def test_load_from_env(self):
        """Alert threshold and hard limit are read from the given env mapping."""
        alert, hard = load_daily_open_limits_from_env(
            {"DAILY_OPEN_ALERT_THRESHOLD": "3", "DAILY_OPEN_HARD_LIMIT": "8"}
        )
        self.assertEqual(alert, 3)
        self.assertEqual(hard, 8)

    def test_hard_limit_blocks(self):
        """A limit of 0 never blocks; otherwise blocking starts at the limit."""
        self.assertFalse(daily_open_hard_limit_blocks(4, 0))
        self.assertFalse(daily_open_hard_limit_blocks(4, 5))
        self.assertTrue(daily_open_hard_limit_blocks(5, 5))

    def test_check_daily_open_hard_limit(self):
        """Reaching the limit yields (False, reason, count) with a readable reason."""
        conn = _FakeConn(5)
        # NOTE(review): the meaning of the trailing 5 and 8 arguments is not
        # visible here; the "8:00" assertion suggests 8 is a reset hour - confirm
        # against check_daily_open_hard_limit's signature.
        ok, reason, n = check_daily_open_hard_limit(conn, "2026-06-07", 5, 8)
        self.assertFalse(ok)
        self.assertEqual(n, 5)
        # "已达上限" reads "limit reached".
        self.assertIn("已达上限", reason)
        self.assertIn("8:00", reason)

    def test_count_opens(self):
        """The count comes straight from the connection's single-row result."""
        self.assertEqual(count_opens_for_trading_day(_FakeConn(3), "2026-06-07"), 3)

    def test_can_trade_new_open(self):
        """Opening is allowed below the hard limit and refused once it is reached."""
        self.assertTrue(
            can_trade_new_open(
                time_allows=True,
                active_count=0,
                max_active_positions=1,
                opens_today=2,
                hard_limit=5,
            )
        )
        self.assertFalse(
            can_trade_new_open(
                time_allows=True,
                active_count=0,
                max_active_positions=1,
                opens_today=5,
                hard_limit=5,
            )
        )

    def test_alert_crossing(self):
        """The alert fires on reaching the threshold, not on later increments."""
        # Presumably the arguments are (previous count, new count, threshold).
        self.assertTrue(should_send_daily_open_alert(4, 5, 5))
        self.assertFalse(should_send_daily_open_alert(5, 6, 5))

    def test_prompt_includes_hard_limit(self):
        """The alert prompt mentions the hard limit ("硬上限" = "hard limit")."""
        txt = build_daily_open_alert_prompt("2026-06-07", 5, 5, hard_limit=8)
        self.assertIn("硬上限 8", txt)

    def test_counter_line(self):
        """The counter line shows the current count alongside the hard limit."""
        # Presumably the arguments are (opens today, alert threshold, hard limit).
        line = format_daily_open_counter_line(3, 5, 8)
        self.assertIn("3 / 硬上限 8", line)


if __name__ == "__main__":
    unittest.main()