diff --git a/strategy_roll_lib.py b/strategy_roll_lib.py index f722b02..c7ab12a 100644 --- a/strategy_roll_lib.py +++ b/strategy_roll_lib.py @@ -266,13 +266,13 @@ def validate_roll_geometry( if direction == "long": if sl >= bp: return "做多:止损须低于突破价" - if mark_price is not None and float(mark_price) <= bp: - return "做多:当前价须高于突破价(等待向上突破)" + if mark_price is not None and float(mark_price) >= bp: + return "做多:当前价须低于突破价(等待向上突破)" else: if sl <= bp: return "做空:止损须高于突破价" - if mark_price is not None and float(mark_price) >= bp: - return "做空:当前价须低于突破价(等待向下突破)" + if mark_price is not None and float(mark_price) <= bp: + return "做空:当前价须高于突破价(等待向下跌破)" else: return "加仓方式无效" diff --git a/tests/test_strategy_roll_lib.py b/tests/test_strategy_roll_lib.py index a01b063..502c96d 100644 --- a/tests/test_strategy_roll_lib.py +++ b/tests/test_strategy_roll_lib.py @@ -5,6 +5,7 @@ from strategy_roll_lib import ( roll_fib_invalidate, roll_fib_trigger_crossed, solve_add_amount_for_total_risk, + validate_roll_geometry, ) @@ -64,7 +65,48 @@ def test_preview_breakout_mode_label(): breakthrough_price=3100.0, risk_percent=10.0, capital_base_usdt=1000.0, - add_price=3150.0, + add_price=3050.0, ) assert err is None assert preview["add_mode_label"] == "突破加仓" + + +def test_breakout_geometry_short_mark_above_breakout(): + err = validate_roll_geometry( + "short", + "breakout", + new_stop_loss=568.0, + breakthrough_price=551.0, + entry_existing=560.0, + initial_take_profit=540.0, + mark_price=560.0, + ) + assert err is None + + +def test_breakout_geometry_short_rejects_mark_at_or_below_breakout(): + err = validate_roll_geometry( + "short", + "breakout", + new_stop_loss=568.0, + breakthrough_price=551.0, + entry_existing=560.0, + initial_take_profit=540.0, + mark_price=551.0, + ) + assert err is not None + assert "高于突破价" in err + + +def test_breakout_geometry_long_rejects_mark_at_or_above_breakout(): + err = validate_roll_geometry( + "long", + "breakout", + new_stop_loss=2980.0, + breakthrough_price=3100.0, + entry_existing=3000.0, + initial_take_profit=3500.0, + mark_price=3100.0, + ) + assert err is not None + assert "低于突破价" in err