diff --git a/strategy_trend_lib.py b/strategy_trend_lib.py index 1bca74d..9c0f712 100644 --- a/strategy_trend_lib.py +++ b/strategy_trend_lib.py @@ -359,48 +359,46 @@ def append_leg_fill_price_json(existing_json: str | None, fill_px: float) -> str return json.dumps(fills, ensure_ascii=False, separators=(",", ":")) -def _trend_leg_contracts( - leg_idx: int, first_amt: float, leg_amounts: list[float] -) -> float: +def trend_leg_grid_price(plan: dict, leg_idx: int) -> Optional[float]: + """补仓 leg_idx(1..N) 的计划网格触发价;首仓返回 None。""" + if leg_idx <= 0: + return None + try: + grid = [float(x) for x in json.loads((plan or {}).get("grid_prices_json") or "[]")] + except Exception: + grid = [] + gi = leg_idx - 1 + if 0 <= gi < len(grid): + return float(grid[gi]) + return None + + +def trend_leg_display_price(plan: dict, leg_idx: int) -> Optional[float]: + """ + 四所统一:单档展示价 = leg_fill_prices_json 实际记录,否则计划网格(首仓用均价/参考价)。 + 禁止为凑均价反推虚构成交价。 + """ + p = plan or {} + fills = parse_leg_fill_prices(p) + if len(fills) > leg_idx: + return float(fills[leg_idx]) if leg_idx == 0: - return float(first_amt) - li = leg_idx - 1 - if 0 <= li < len(leg_amounts): - return float(leg_amounts[li]) - return 0.0 - - -def _infer_trend_fill_from_target_avg( - leg_idx: int, - prices: list[float], - *, - first_amt: float, - leg_amounts: list[float], - legs_done: int, - target_avg: float, -) -> float: - """已知其余档位成交价时,反推单档成交价使加权均价等于 target_avg。""" - total_contracts = 0.0 - known_cost = 0.0 - unknown_amt = _trend_leg_contracts(leg_idx, first_amt, leg_amounts) - for i in range(legs_done + 1): - amt = _trend_leg_contracts(i, first_amt, leg_amounts) - if amt <= 0: - continue - total_contracts += amt - if i == leg_idx: - continue - known_cost += float(prices[i]) * amt - if unknown_amt <= 0 or total_contracts <= 0: - return float(prices[leg_idx]) - return (float(target_avg) * total_contracts - known_cost) / unknown_amt + try: + return float(p.get("avg_entry_price")) + except (TypeError, ValueError): + pass + try: + ref = p.get("live_price_ref") + if ref not in (None, ""): + return float(ref) + except (TypeError, ValueError): + pass + return None + return trend_leg_grid_price(p, leg_idx) def reconcile_trend_leg_fill_prices(plan: dict) -> list[float]: - """ - 首仓(0)+已补仓(1..legs_done) 成交价。 - 优先 leg_fill_prices_json;缺口用计划网格价;再对齐 avg_entry_price。 - """ + """首仓(0)+已补仓(1..legs_done) 展示价列表(四所共用 trend_leg_display_price)。""" p = plan or {} if int(p.get("first_order_done") or 0) == 0: return [] @@ -408,64 +406,10 @@ def reconcile_trend_leg_fill_prices(plan: dict) -> list[float]: legs_done = int(p.get("legs_done") or 0) except (TypeError, ValueError): legs_done = 0 - try: - first_amt = float(p.get("first_order_amount")) - except (TypeError, ValueError): - first_amt = 0.0 - try: - target_avg = float(p.get("avg_entry_price")) - except (TypeError, ValueError): - target_avg = None - - fills = parse_leg_fill_prices(p) - try: - grid = [float(x) for x in json.loads(p.get("grid_prices_json") or "[]")] - except Exception: - grid = [] - try: - leg_amounts = [float(x) for x in json.loads(p.get("leg_amounts_json") or "[]")] - except Exception: - leg_amounts = [] - - def _default_px(leg_idx: int) -> float: - if leg_idx == 0: - if target_avg is not None and legs_done == 0: - return target_avg - try: - return float(p.get("avg_entry_price")) - except (TypeError, ValueError): - pass - try: - ref = p.get("live_price_ref") - if ref not in (None, ""): - return float(ref) - except (TypeError, ValueError): - pass - return 0.0 - gi = leg_idx - 1 - if 0 <= gi < len(grid): - return float(grid[gi]) - return _default_px(0) - result: list[float] = [] - estimated: list[int] = [] for leg_idx in range(legs_done + 1): - if len(fills) > leg_idx: - result.append(float(fills[leg_idx])) - else: - result.append(_default_px(leg_idx)) - estimated.append(leg_idx) - - if target_avg is not None and estimated: - adjust_idx = estimated[0] if len(estimated) == 1 else estimated[-1] - result[adjust_idx] = _infer_trend_fill_from_target_avg( - adjust_idx, - result, - first_amt=first_amt, - leg_amounts=leg_amounts, - legs_done=legs_done, - target_avg=target_avg, - ) + px = trend_leg_display_price(p, leg_idx) + result.append(float(px) if px is not None else 0.0) return result @@ -618,7 +562,10 @@ def build_trend_preview_level_rows(preview: dict) -> tuple[dict, list[dict]]: def enrich_trend_dca_levels_with_tp(plan: dict, levels: list[dict]) -> list[dict]: - """运行中计划:补仓表按实际成交价重算触发价/均价/金额盈亏比;未补档仍用计划触发价预估。""" + """ + 四所统一补仓表 enrich(实例策略页 + 中控 monitor 共用)。 + 触发价:实际成交价或计划网格;末档加仓后均价用持仓均价;禁止反推虚构成交价。 + """ if not levels: return levels p = plan or {} @@ -641,7 +588,10 @@ def enrich_trend_dca_levels_with_tp(plan: dict, levels: list[dict]) -> list[dict except (TypeError, ValueError): legs_done = 0 first_done = int(p.get("first_order_done") or 0) != 0 - reconciled_fills = reconcile_trend_leg_fill_prices(p) + try: + target_avg = float(p.get("avg_entry_price")) + except (TypeError, ValueError): + target_avg = None ref_raw = p.get("live_price_ref") if ref_raw in (None, ""): @@ -671,9 +621,8 @@ def enrich_trend_dca_levels_with_tp(plan: dict, levels: list[dict]) -> list[dict except (TypeError, ValueError): amt_f = first_amt if first_done: - if reconciled_fills: - fill_px = float(reconciled_fills[0]) - else: + fill_px = trend_leg_display_price(p, 0) + if fill_px is None: try: fill_px = float(p.get("avg_entry_price") or ref) except (TypeError, ValueError): @@ -681,8 +630,11 @@ def enrich_trend_dca_levels_with_tp(plan: dict, levels: list[dict]) -> list[dict accumulated = [(float(fill_px), amt_f)] cum_contracts = amt_f row_cum = cum_contracts - row["avg_entry"] = float(fill_px) row["price"] = fill_px + if target_avg is not None and legs_done == 0: + row["avg_entry"] = target_avg + else: + row["avg_entry"] = float(fill_px) else: accumulated = [(ref, amt_f)] cum_contracts = amt_f @@ -704,19 +656,19 @@ def enrich_trend_dca_levels_with_tp(plan: dict, levels: list[dict]) -> list[dict leg_contracts = 0.0 done = row.get("status") == "done" or (leg_num > 0 and leg_num <= legs_done) if done and leg_contracts > 0: - if leg_num < len(reconciled_fills): - fill_px = float(reconciled_fills[leg_num]) - elif grid_trigger_f is not None: - fill_px = grid_trigger_f - else: - fill_px = ref + fill_px = trend_leg_display_price(p, leg_num) + if fill_px is None: + fill_px = grid_trigger_f if grid_trigger_f is not None else ref row["price"] = fill_px accumulated.append((fill_px, leg_contracts)) cum_contracts += leg_contracts row_cum = cum_contracts - avg = weighted_avg_entry(accumulated) - if avg is not None: - row["avg_entry"] = avg + if leg_num == legs_done and target_avg is not None: + row["avg_entry"] = target_avg + else: + avg = weighted_avg_entry(accumulated) + if avg is not None: + row["avg_entry"] = avg elif grid_trigger_f is not None and leg_contracts > 0: row["price"] = grid_trigger_f projected = accumulated + [(grid_trigger_f, leg_contracts)] diff --git a/strategy_trend_register.py b/strategy_trend_register.py index bf81660..f17a1af 100644 --- a/strategy_trend_register.py +++ b/strategy_trend_register.py @@ -457,14 +457,14 @@ def _trend_add_leg_fields(cfg: dict, d: dict) -> dict: grid = [] add_prices: list[float] = [] try: - from strategy_trend_lib import reconcile_trend_leg_fill_prices + from strategy_trend_lib import trend_leg_display_price - fills = reconcile_trend_leg_fill_prices(out) for i in range(1, legs_done + 1): - if i < len(fills): - add_prices.append(float(fills[i])) + px = trend_leg_display_price(out, i) + if px is not None: + add_prices.append(float(px)) except Exception: - fills = [] + pass if not add_prices: for x in grid[:legs_done]: try: diff --git a/tests/test_trend_dca_enrich_fills.py b/tests/test_trend_dca_enrich_fills.py index 03e0e4c..68f90d5 100644 --- a/tests/test_trend_dca_enrich_fills.py +++ b/tests/test_trend_dca_enrich_fills.py @@ -10,7 +10,10 @@ ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(ROOT)) from strategy_snapshot_lib import attach_trend_dca_levels # noqa: E402 -from strategy_trend_lib import calc_trend_plan_money_metrics # noqa: E402 +from strategy_trend_lib import ( # noqa: E402 + calc_trend_plan_money_metrics, + trend_leg_display_price, +) class TestTrendDcaEnrichFills(unittest.TestCase): @@ -62,8 +65,8 @@ class TestTrendDcaEnrichFills(unittest.TestCase): self.assertEqual(dca2["status"], "pending") self.assertAlmostEqual(dca2["price"], 0.343, places=4) - def test_missing_dca_fills_align_last_avg_with_header(self): - """缺补仓成交价时,末档加仓后均价应对齐计划头部 avg_entry_price。""" + def test_missing_dca_fills_use_grid_trigger_not_inferred_price(self): + """缺补仓成交价时:触发价用计划网格,末档均价对齐头部,禁止反推离谱成交价。""" plan = self._base_plan( legs_done=2, avg_entry_price=0.3507, @@ -73,10 +76,25 @@ class TestTrendDcaEnrichFills(unittest.TestCase): ) enriched = attach_trend_dca_levels(plan) levels = enriched["dca_levels"] + dca1 = levels[1] dca2 = levels[2] + self.assertEqual(dca1["status"], "done") + self.assertAlmostEqual(dca1["price"], 0.343, places=4) self.assertEqual(dca2["status"], "done") + self.assertAlmostEqual(dca2["price"], 0.343, places=4) self.assertAlmostEqual(dca2["avg_entry"], 0.3507, places=4) - self.assertGreater(dca2["price"], 0.343) + self.assertLess(dca2["price"], 0.36) + + def test_display_price_never_infers_from_target_avg(self): + """四所共用:缺记录时只用网格,不因均价反推离谱触发价。""" + plan = self._base_plan( + legs_done=2, + avg_entry_price=0.3507, + leg_fill_prices_json=json.dumps([0.3436]), + grid_prices_json=json.dumps([0.343, 0.343, 0.3395, 0.336, 0.3325]), + ) + self.assertAlmostEqual(trend_leg_display_price(plan, 2), 0.343, places=4) + self.assertLess(trend_leg_display_price(plan, 2), 0.36) if __name__ == "__main__": diff --git a/tests/test_trend_hub_enrich_unified.py b/tests/test_trend_hub_enrich_unified.py index 3445ea0..e0079b7 100644 --- a/tests/test_trend_hub_enrich_unified.py +++ b/tests/test_trend_hub_enrich_unified.py @@ -81,7 +81,9 @@ class TestTrendHubEnrichUnified(unittest.TestCase): self.assertIn("dca_levels", hub) last_done = hub["dca_levels"][2] self.assertEqual(last_done["status"], "done") + self.assertAlmostEqual(last_done["price"], 0.343, places=4) self.assertAlmostEqual(last_done["avg_entry"], 0.3507, places=4) + self.assertLess(last_done["price"], 0.36) self.assertEqual(hub.get("monitor_source"), "趋势回调计划") self.assertEqual(hub.get("add_count"), 2)