From ef790c8a800031671576d80e7ad42c83723ee8f6 Mon Sep 17 00:00:00 2001 From: dekun Date: Wed, 1 Jul 2026 08:24:27 +0800 Subject: [PATCH] Fix migration: per-table connections and skip non-serial setval. Co-authored-by: Cursor --- scripts/migrate_sqlite_to_postgres.py | 62 ++++++++++++++------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/scripts/migrate_sqlite_to_postgres.py b/scripts/migrate_sqlite_to_postgres.py index cfdd579..bf338f5 100644 --- a/scripts/migrate_sqlite_to_postgres.py +++ b/scripts/migrate_sqlite_to_postgres.py @@ -38,9 +38,15 @@ def _pg_columns(pg_conn, table: str) -> list[str]: def _reset_sequences(pg_conn, table: str, pk: str = "id") -> None: + row = pg_conn.execute( + "SELECT pg_get_serial_sequence(%s, %s) AS seq", (table, pk) + ).fetchone() + seq = (row["seq"] if row else None) or None + if not seq: + return pg_conn.execute( - f"SELECT setval(pg_get_serial_sequence('{table}', '{pk}'), " - f"COALESCE((SELECT MAX({pk}) FROM {table}), 1), true)" + f"SELECT setval(%s, COALESCE((SELECT MAX({pk}) FROM {table}), 1), true)", + (seq,), ) @@ -57,7 +63,6 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict: src = sqlite3.connect(src_path) src.row_factory = sqlite3.Row - dst = connect_db() stats: dict[str, int] = {} tables = _sqlite_tables(src) @@ -65,44 +70,43 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict: try: for table in tables: - pg_cols = _pg_columns(dst, table) - if not pg_cols: - print(f" 跳过 {table}(PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)") - continue - src_cols = [c[1] for c in src.execute(f"PRAGMA table_info({table})").fetchall()] - cols = [c for c in src_cols if c in pg_cols] - if not cols: - print(f" 跳过 {table}(无共同列)") - continue - rows = src.execute(f"SELECT {', '.join(cols)} FROM {table}").fetchall() - if dry_run: - stats[table] = len(rows) - print(f" [dry-run] {table}: {len(rows)} 行") - continue - if not rows: - stats[table] = 0 - continue + tbl = connect_db() try: - dst.execute(f"DELETE FROM {table}") + pg_cols = _pg_columns(tbl, table) + if not pg_cols: + print(f" 跳过 {table}(PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)") + continue + src_cols = [c[1] for c in src.execute(f"PRAGMA table_info({table})").fetchall()] + cols = [c for c in src_cols if c in pg_cols] + if not cols: + print(f" 跳过 {table}(无共同列)") + continue + rows = src.execute(f"SELECT {', '.join(cols)} FROM {table}").fetchall() + if dry_run: + stats[table] = len(rows) + print(f" [dry-run] {table}: {len(rows)} 行") + continue + if not rows: + stats[table] = 0 + continue + tbl.execute(f"DELETE FROM {table}") placeholders = ", ".join(["?"] * len(cols)) col_sql = ", ".join(cols) insert_sql = f"INSERT INTO {table} ({col_sql}) VALUES ({placeholders})" for row in rows: - dst.execute(insert_sql, tuple(row[c] for c in cols)) + tbl.execute(insert_sql, tuple(row[c] for c in cols)) if "id" in cols: - try: - _reset_sequences(dst, table, "id") - except Exception: - rollback_if_postgres(dst) - dst.commit() + _reset_sequences(tbl, table, "id") + tbl.commit() stats[table] = len(rows) print(f" {table}: {len(rows)} 行") except Exception as exc: - rollback_if_postgres(dst) + rollback_if_postgres(tbl) raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc + finally: + tbl.close() finally: src.close() - dst.close() total = sum(stats.values()) print(f"==> 完成,共迁移 {total} 行")