Fix migration: per-table connections and skip non-serial setval.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-07-01 08:24:27 +08:00
parent 6abe06d935
commit ef790c8a80
+33 -29
View File
@@ -38,9 +38,15 @@ def _pg_columns(pg_conn, table: str) -> list[str]:
def _reset_sequences(pg_conn, table: str, pk: str = "id") -> None:
row = pg_conn.execute(
"SELECT pg_get_serial_sequence(%s, %s) AS seq", (table, pk)
).fetchone()
seq = (row["seq"] if row else None) or None
if not seq:
return
pg_conn.execute(
f"SELECT setval(pg_get_serial_sequence('{table}', '{pk}'), "
f"COALESCE((SELECT MAX({pk}) FROM {table}), 1), true)"
f"SELECT setval(%s, COALESCE((SELECT MAX({pk}) FROM {table}), 1), true)",
(seq,),
)
@@ -57,7 +63,6 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
src = sqlite3.connect(src_path)
src.row_factory = sqlite3.Row
dst = connect_db()
stats: dict[str, int] = {}
tables = _sqlite_tables(src)
@@ -65,44 +70,43 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
try:
for table in tables:
pg_cols = _pg_columns(dst, table)
if not pg_cols:
print(f" 跳过 {table}PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)")
continue
src_cols = [c[1] for c in src.execute(f"PRAGMA table_info({table})").fetchall()]
cols = [c for c in src_cols if c in pg_cols]
if not cols:
print(f" 跳过 {table}(无共同列)")
continue
rows = src.execute(f"SELECT {', '.join(cols)} FROM {table}").fetchall()
if dry_run:
stats[table] = len(rows)
print(f" [dry-run] {table}: {len(rows)}")
continue
if not rows:
stats[table] = 0
continue
tbl = connect_db()
try:
dst.execute(f"DELETE FROM {table}")
pg_cols = _pg_columns(tbl, table)
if not pg_cols:
print(f" 跳过 {table}PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)")
continue
src_cols = [c[1] for c in src.execute(f"PRAGMA table_info({table})").fetchall()]
cols = [c for c in src_cols if c in pg_cols]
if not cols:
print(f" 跳过 {table}(无共同列)")
continue
rows = src.execute(f"SELECT {', '.join(cols)} FROM {table}").fetchall()
if dry_run:
stats[table] = len(rows)
print(f" [dry-run] {table}: {len(rows)}")
continue
if not rows:
stats[table] = 0
continue
tbl.execute(f"DELETE FROM {table}")
placeholders = ", ".join(["?"] * len(cols))
col_sql = ", ".join(cols)
insert_sql = f"INSERT INTO {table} ({col_sql}) VALUES ({placeholders})"
for row in rows:
dst.execute(insert_sql, tuple(row[c] for c in cols))
tbl.execute(insert_sql, tuple(row[c] for c in cols))
if "id" in cols:
try:
_reset_sequences(dst, table, "id")
except Exception:
rollback_if_postgres(dst)
dst.commit()
_reset_sequences(tbl, table, "id")
tbl.commit()
stats[table] = len(rows)
print(f" {table}: {len(rows)}")
except Exception as exc:
rollback_if_postgres(dst)
rollback_if_postgres(tbl)
raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc
finally:
tbl.close()
finally:
src.close()
dst.close()
total = sum(stats.values())
print(f"==> 完成,共迁移 {total}")