diff --git a/scripts/migrate_sqlite_to_postgres.py b/scripts/migrate_sqlite_to_postgres.py index 720af8a..66aec55 100644 --- a/scripts/migrate_sqlite_to_postgres.py +++ b/scripts/migrate_sqlite_to_postgres.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # Copyright (c) 2025-2026 马建军. All rights reserved. -"""将 SQLite futures.db 迁移到 PostgreSQL(需已配置 DATABASE_URL 并 init 空库)。""" +"""将 SQLite futures.db 迁移到 PostgreSQL(需已 init 表结构,见 deploy_postgres.sh)。""" from __future__ import annotations import argparse @@ -17,7 +17,7 @@ from dotenv import load_dotenv load_dotenv(ROOT / ".env") -from db_conn import DB_PATH, connect_db, db_backend, is_postgres # noqa: E402 +from db_conn import DB_PATH, connect_db, db_backend, is_postgres, rollback_if_postgres # noqa: E402 def _sqlite_tables(conn: sqlite3.Connection) -> list[str]: @@ -40,14 +40,12 @@ def _pg_columns(pg_conn, table: str) -> list[str]: def _reset_sequences(pg_conn, table: str, pk: str = "id") -> None: try: pg_conn.execute( - f"""SELECT setval( - pg_get_serial_sequence('{table}', '{pk}'), - COALESCE((SELECT MAX({pk}) FROM {table}), 1), - true - )""" + f"SELECT setval(pg_get_serial_sequence('{table}', '{pk}'), " + f"COALESCE((SELECT MAX({pk}) FROM {table}), 1), true)" ) + pg_conn.commit() except Exception: - pass + rollback_if_postgres(pg_conn) def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict: @@ -61,12 +59,6 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict: print(f"==> 源库: {src_path}") print(f"==> 目标: PostgreSQL ({os.getenv('DATABASE_URL', '').split('@')[-1]})") - if not dry_run: - print("==> 初始化 PostgreSQL 表结构...") - from app import init_db - - init_db() - src = sqlite3.connect(src_path) src.row_factory = sqlite3.Row dst = connect_db() @@ -79,7 +71,7 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict: for table in tables: pg_cols = _pg_columns(dst, table) if not pg_cols: - print(f" 跳过 {table}(PostgreSQL 无此表,请先 init_db)") + print(f" 跳过 {table}(PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)") continue src_cols = [c[1] for c in src.execute(f"PRAGMA table_info({table})").fetchall()] cols = [c for c in src_cols if c in pg_cols] @@ -94,18 +86,22 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict: if not rows: stats[table] = 0 continue - dst.execute(f"DELETE FROM {table}") - placeholders = ", ".join(["?"] * len(cols)) - col_sql = ", ".join(cols) - insert_sql = f"INSERT INTO {table} ({col_sql}) VALUES ({placeholders})" - for row in rows: - dst.execute(insert_sql, tuple(row[c] for c in cols)) - stats[table] = len(rows) - if "id" in cols: - _reset_sequences(dst, table, "id") - print(f" {table}: {len(rows)} 行") - if not dry_run: - dst.commit() + try: + dst.execute(f"DELETE FROM {table}") + placeholders = ", ".join(["?"] * len(cols)) + col_sql = ", ".join(cols) + insert_sql = f"INSERT INTO {table} ({col_sql}) VALUES ({placeholders})" + for row in rows: + dst.execute(insert_sql, tuple(row[c] for c in cols)) + if "id" in cols: + _reset_sequences(dst, table, "id") + else: + dst.commit() + stats[table] = len(rows) + print(f" {table}: {len(rows)} 行") + except Exception as exc: + rollback_if_postgres(dst) + raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc finally: src.close() dst.close()