Inline account_risk_state DDL in init_db; use autocommit for migration.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-07-01 08:26:13 +08:00
parent 06fbff04a7
commit d7ea7b9e8a
2 changed files with 43 additions and 4 deletions
+18
View File
@@ -443,6 +443,24 @@ def init_db():
rollback_if_postgres(conn) rollback_if_postgres(conn)
ensure_kline_tables(conn) ensure_kline_tables(conn)
init_strategy_tables(conn) init_strategy_tables(conn)
conn.execute(
"""CREATE TABLE IF NOT EXISTS account_risk_state (
id INTEGER PRIMARY KEY CHECK (id = 1),
trading_day TEXT,
manual_close_count INTEGER DEFAULT 0,
cooloff_until_ms INTEGER,
cooloff_hours INTEGER,
daily_frozen INTEGER DEFAULT 0,
last_close_at_ms INTEGER,
updated_at TEXT
)"""
)
if not conn.execute("SELECT id FROM account_risk_state WHERE id=1").fetchone():
conn.execute(
"INSERT INTO account_risk_state (id, trading_day, manual_close_count, daily_frozen) "
"VALUES (1, '', 0, 0)"
)
conn.commit()
from risk.account_risk_lib import ensure_account_risk_schema from risk.account_risk_lib import ensure_account_risk_schema
from recommend_store import ensure_recommend_tables from recommend_store import ensure_recommend_tables
ensure_account_risk_schema(conn) ensure_account_risk_schema(conn)
+25 -4
View File
@@ -13,11 +13,34 @@ ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path: if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT)) sys.path.insert(0, str(ROOT))
try:
import psycopg
from psycopg.rows import dict_row
except ImportError:
psycopg = None # type: ignore[assignment]
dict_row = None # type: ignore[assignment]
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv(ROOT / ".env") load_dotenv(ROOT / ".env")
from db_conn import DB_PATH, connect_db, db_backend, is_postgres, rollback_if_postgres # noqa: E402 from db_conn import DB_PATH, db_backend, is_postgres # noqa: E402
def _migrate_conn():
"""迁移专用连接(autocommit,避免单表失败污染事务)。"""
if not is_postgres():
raise RuntimeError("需要 PostgreSQL")
url = (os.getenv("DATABASE_URL") or "").strip()
raw = psycopg.connect(url, row_factory=dict_row, autocommit=True)
try:
with raw.cursor() as cur:
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
except Exception:
pass
from db_conn import DbConnection
return DbConnection("postgres", raw, from_pool=False)
def _sqlite_tables(conn: sqlite3.Connection) -> list[str]: def _sqlite_tables(conn: sqlite3.Connection) -> list[str]:
@@ -70,7 +93,7 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
try: try:
for table in tables: for table in tables:
tbl = connect_db() tbl = _migrate_conn()
try: try:
pg_cols = _pg_columns(tbl, table) pg_cols = _pg_columns(tbl, table)
if not pg_cols: if not pg_cols:
@@ -97,11 +120,9 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
tbl.execute(insert_sql, tuple(row[c] for c in cols)) tbl.execute(insert_sql, tuple(row[c] for c in cols))
if "id" in cols: if "id" in cols:
_reset_sequences(tbl, table, "id") _reset_sequences(tbl, table, "id")
tbl.commit()
stats[table] = len(rows) stats[table] = len(rows)
print(f" {table}: {len(rows)}") print(f" {table}: {len(rows)}")
except Exception as exc: except Exception as exc:
rollback_if_postgres(tbl)
raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc
finally: finally:
tbl.close() tbl.close()