Fix migration: per-table connections and skip non-serial setval.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -38,9 +38,15 @@ def _pg_columns(pg_conn, table: str) -> list[str]:
|
||||
|
||||
|
||||
def _reset_sequences(pg_conn, table: str, pk: str = "id") -> None:
|
||||
row = pg_conn.execute(
|
||||
"SELECT pg_get_serial_sequence(%s, %s) AS seq", (table, pk)
|
||||
).fetchone()
|
||||
seq = (row["seq"] if row else None) or None
|
||||
if not seq:
|
||||
return
|
||||
pg_conn.execute(
|
||||
f"SELECT setval(pg_get_serial_sequence('{table}', '{pk}'), "
|
||||
f"COALESCE((SELECT MAX({pk}) FROM {table}), 1), true)"
|
||||
f"SELECT setval(%s, COALESCE((SELECT MAX({pk}) FROM {table}), 1), true)",
|
||||
(seq,),
|
||||
)
|
||||
|
||||
|
||||
@@ -57,7 +63,6 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
|
||||
|
||||
src = sqlite3.connect(src_path)
|
||||
src.row_factory = sqlite3.Row
|
||||
dst = connect_db()
|
||||
|
||||
stats: dict[str, int] = {}
|
||||
tables = _sqlite_tables(src)
|
||||
@@ -65,44 +70,43 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
|
||||
|
||||
try:
|
||||
for table in tables:
|
||||
pg_cols = _pg_columns(dst, table)
|
||||
if not pg_cols:
|
||||
print(f" 跳过 {table}(PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)")
|
||||
continue
|
||||
src_cols = [c[1] for c in src.execute(f"PRAGMA table_info({table})").fetchall()]
|
||||
cols = [c for c in src_cols if c in pg_cols]
|
||||
if not cols:
|
||||
print(f" 跳过 {table}(无共同列)")
|
||||
continue
|
||||
rows = src.execute(f"SELECT {', '.join(cols)} FROM {table}").fetchall()
|
||||
if dry_run:
|
||||
stats[table] = len(rows)
|
||||
print(f" [dry-run] {table}: {len(rows)} 行")
|
||||
continue
|
||||
if not rows:
|
||||
stats[table] = 0
|
||||
continue
|
||||
tbl = connect_db()
|
||||
try:
|
||||
dst.execute(f"DELETE FROM {table}")
|
||||
pg_cols = _pg_columns(tbl, table)
|
||||
if not pg_cols:
|
||||
print(f" 跳过 {table}(PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)")
|
||||
continue
|
||||
src_cols = [c[1] for c in src.execute(f"PRAGMA table_info({table})").fetchall()]
|
||||
cols = [c for c in src_cols if c in pg_cols]
|
||||
if not cols:
|
||||
print(f" 跳过 {table}(无共同列)")
|
||||
continue
|
||||
rows = src.execute(f"SELECT {', '.join(cols)} FROM {table}").fetchall()
|
||||
if dry_run:
|
||||
stats[table] = len(rows)
|
||||
print(f" [dry-run] {table}: {len(rows)} 行")
|
||||
continue
|
||||
if not rows:
|
||||
stats[table] = 0
|
||||
continue
|
||||
tbl.execute(f"DELETE FROM {table}")
|
||||
placeholders = ", ".join(["?"] * len(cols))
|
||||
col_sql = ", ".join(cols)
|
||||
insert_sql = f"INSERT INTO {table} ({col_sql}) VALUES ({placeholders})"
|
||||
for row in rows:
|
||||
dst.execute(insert_sql, tuple(row[c] for c in cols))
|
||||
tbl.execute(insert_sql, tuple(row[c] for c in cols))
|
||||
if "id" in cols:
|
||||
try:
|
||||
_reset_sequences(dst, table, "id")
|
||||
except Exception:
|
||||
rollback_if_postgres(dst)
|
||||
dst.commit()
|
||||
_reset_sequences(tbl, table, "id")
|
||||
tbl.commit()
|
||||
stats[table] = len(rows)
|
||||
print(f" {table}: {len(rows)} 行")
|
||||
except Exception as exc:
|
||||
rollback_if_postgres(dst)
|
||||
rollback_if_postgres(tbl)
|
||||
raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc
|
||||
finally:
|
||||
tbl.close()
|
||||
finally:
|
||||
src.close()
|
||||
dst.close()
|
||||
|
||||
total = sum(stats.values())
|
||||
print(f"==> 完成,共迁移 {total} 行")
|
||||
|
||||
Reference in New Issue
Block a user