"""Tests for ``count_position_limit_active_monitors``.

Each test builds an isolated in-memory SQLite database containing a minimal
``order_monitors`` table plus whatever ``init_strategy_tables`` creates, inserts
a few monitor rows, and checks how many of them the counter reports.
"""

import sqlite3
import unittest

from strategy_db import init_strategy_tables
from strategy_trade_labels import (
    MONITOR_TYPE_ROLL,
    MONITOR_TYPE_TREND_PULLBACK,
    count_position_limit_active_monitors,
)

# ``monitor_type`` value these tests use for a regular order monitor
# (one that is neither a trend-pullback nor a roll monitor).
REGULAR_MONITOR_TYPE = "下单监控"


def _mem_conn():
    """Return a fresh in-memory SQLite connection prepared for these tests.

    ``order_monitors`` is created here with only the columns the tests touch;
    ``init_strategy_tables`` is then called to create the strategy tables
    (the roll-group test relies on it creating ``roll_groups``). Rows are
    returned as ``sqlite3.Row``. The caller owns the connection and is
    responsible for closing it.
    """
    conn = sqlite3.connect(":memory:")
    conn.row_factory = sqlite3.Row
    conn.execute(
        """CREATE TABLE order_monitors (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            symbol TEXT,
            direction TEXT,
            status TEXT,
            monitor_type TEXT,
            key_signal_type TEXT,
            trend_plan_id INTEGER
        )"""
    )
    init_strategy_tables(conn)
    return conn


class PositionLimitCountTests(unittest.TestCase):
    """Check which active ``order_monitors`` rows count toward the position limit."""

    def setUp(self):
        self.conn = _mem_conn()
        # Close explicitly: an unclosed sqlite3 connection leaks and, on
        # Python 3.13+, emits ResourceWarning when garbage-collected.
        self.addCleanup(self.conn.close)

    def test_regular_monitor_counts(self):
        self.conn.execute(
            "INSERT INTO order_monitors (symbol, status, monitor_type) "
            "VALUES ('ETH/USDT', 'active', ?)",
            (REGULAR_MONITOR_TYPE,),
        )
        self.conn.commit()
        self.assertEqual(count_position_limit_active_monitors(self.conn), 1)

    def test_trend_pullback_excluded(self):
        self.conn.execute(
            """INSERT INTO order_monitors (symbol, status, monitor_type, trend_plan_id)
               VALUES ('ETH/USDT', 'active', ?, 12)""",
            (MONITOR_TYPE_TREND_PULLBACK,),
        )
        self.conn.commit()
        self.assertEqual(count_position_limit_active_monitors(self.conn), 0)

    def test_roll_monitor_type_excluded(self):
        self.conn.execute(
            "INSERT INTO order_monitors (symbol, status, monitor_type) "
            "VALUES ('ETH/USDT', 'active', ?)",
            (MONITOR_TYPE_ROLL,),
        )
        self.conn.commit()
        self.assertEqual(count_position_limit_active_monitors(self.conn), 0)

    def test_active_roll_group_excludes_monitor(self):
        # A regular monitor that is referenced by an active roll group must
        # not be counted.
        self.conn.execute(
            "INSERT INTO order_monitors (id, symbol, status, monitor_type) "
            "VALUES (1, 'ETH/USDT', 'active', ?)",
            (REGULAR_MONITOR_TYPE,),
        )
        self.conn.execute(
            """INSERT INTO roll_groups (order_monitor_id, symbol, direction, status)
               VALUES (1, 'ETH/USDT', 'long', 'active')"""
        )
        self.conn.commit()
        self.assertEqual(count_position_limit_active_monitors(self.conn), 0)

    def test_mixed_monitors(self):
        # One regular monitor (counted) plus one trend-pullback monitor
        # (excluded) should yield a count of 1.
        self.conn.execute(
            "INSERT INTO order_monitors (symbol, status, monitor_type) "
            "VALUES ('BTC/USDT', 'active', ?)",
            (REGULAR_MONITOR_TYPE,),
        )
        self.conn.execute(
            """INSERT INTO order_monitors (symbol, status, monitor_type, trend_plan_id)
               VALUES ('ETH/USDT', 'active', ?, 3)""",
            (MONITOR_TYPE_TREND_PULLBACK,),
        )
        self.conn.commit()
        self.assertEqual(count_position_limit_active_monitors(self.conn), 1)


if __name__ == "__main__":
    unittest.main()