diff --git a/app.py b/app.py index 3caa530..31047b9 100644 --- a/app.py +++ b/app.py @@ -443,6 +443,24 @@ def init_db(): rollback_if_postgres(conn) ensure_kline_tables(conn) init_strategy_tables(conn) + conn.execute( + """CREATE TABLE IF NOT EXISTS account_risk_state ( + id INTEGER PRIMARY KEY CHECK (id = 1), + trading_day TEXT, + manual_close_count INTEGER DEFAULT 0, + cooloff_until_ms INTEGER, + cooloff_hours INTEGER, + daily_frozen INTEGER DEFAULT 0, + last_close_at_ms INTEGER, + updated_at TEXT + )""" + ) + if not conn.execute("SELECT id FROM account_risk_state WHERE id=1").fetchone(): + conn.execute( + "INSERT INTO account_risk_state (id, trading_day, manual_close_count, daily_frozen) " + "VALUES (1, '', 0, 0)" + ) + conn.commit() from risk.account_risk_lib import ensure_account_risk_schema from recommend_store import ensure_recommend_tables ensure_account_risk_schema(conn) diff --git a/scripts/migrate_sqlite_to_postgres.py b/scripts/migrate_sqlite_to_postgres.py index bf338f5..1f88c6e 100644 --- a/scripts/migrate_sqlite_to_postgres.py +++ b/scripts/migrate_sqlite_to_postgres.py @@ -13,11 +13,34 @@ ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) +try: + import psycopg + from psycopg.rows import dict_row +except ImportError: + psycopg = None # type: ignore[assignment] + dict_row = None # type: ignore[assignment] + from dotenv import load_dotenv load_dotenv(ROOT / ".env") -from db_conn import DB_PATH, connect_db, db_backend, is_postgres, rollback_if_postgres # noqa: E402 +from db_conn import DB_PATH, db_backend, is_postgres # noqa: E402 + + +def _migrate_conn(): + """迁移专用连接(autocommit,避免单表失败污染事务)。""" + if not is_postgres(): + raise RuntimeError("需要 PostgreSQL") + url = (os.getenv("DATABASE_URL") or "").strip() + raw = psycopg.connect(url, row_factory=dict_row, autocommit=True) + try: + with raw.cursor() as cur: + cur.execute("SET TIME ZONE 'Asia/Shanghai'") + except Exception: + pass + from db_conn import DbConnection + + return DbConnection("postgres", raw, from_pool=False) def _sqlite_tables(conn: sqlite3.Connection) -> list[str]: @@ -70,7 +93,7 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict: try: for table in tables: - tbl = connect_db() + tbl = _migrate_conn() try: pg_cols = _pg_columns(tbl, table) if not pg_cols: @@ -97,11 +120,9 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict: tbl.execute(insert_sql, tuple(row[c] for c in cols)) if "id" in cols: _reset_sequences(tbl, table, "id") - tbl.commit() stats[table] = len(rows) print(f" {table}: {len(rows)} 行") except Exception as exc: - rollback_if_postgres(tbl) raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc finally: tbl.close()