Fix migration: per-table connections and skip non-serial setval.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -38,9 +38,15 @@ def _pg_columns(pg_conn, table: str) -> list[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _reset_sequences(pg_conn, table: str, pk: str = "id") -> None:
|
def _reset_sequences(pg_conn, table: str, pk: str = "id") -> None:
|
||||||
|
row = pg_conn.execute(
|
||||||
|
"SELECT pg_get_serial_sequence(%s, %s) AS seq", (table, pk)
|
||||||
|
).fetchone()
|
||||||
|
seq = (row["seq"] if row else None) or None
|
||||||
|
if not seq:
|
||||||
|
return
|
||||||
pg_conn.execute(
|
pg_conn.execute(
|
||||||
f"SELECT setval(pg_get_serial_sequence('{table}', '{pk}'), "
|
f"SELECT setval(%s, COALESCE((SELECT MAX({pk}) FROM {table}), 1), true)",
|
||||||
f"COALESCE((SELECT MAX({pk}) FROM {table}), 1), true)"
|
(seq,),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -57,7 +63,6 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
|
|||||||
|
|
||||||
src = sqlite3.connect(src_path)
|
src = sqlite3.connect(src_path)
|
||||||
src.row_factory = sqlite3.Row
|
src.row_factory = sqlite3.Row
|
||||||
dst = connect_db()
|
|
||||||
|
|
||||||
stats: dict[str, int] = {}
|
stats: dict[str, int] = {}
|
||||||
tables = _sqlite_tables(src)
|
tables = _sqlite_tables(src)
|
||||||
@@ -65,7 +70,9 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
for table in tables:
|
for table in tables:
|
||||||
pg_cols = _pg_columns(dst, table)
|
tbl = connect_db()
|
||||||
|
try:
|
||||||
|
pg_cols = _pg_columns(tbl, table)
|
||||||
if not pg_cols:
|
if not pg_cols:
|
||||||
print(f" 跳过 {table}(PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)")
|
print(f" 跳过 {table}(PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)")
|
||||||
continue
|
continue
|
||||||
@@ -82,27 +89,24 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
|
|||||||
if not rows:
|
if not rows:
|
||||||
stats[table] = 0
|
stats[table] = 0
|
||||||
continue
|
continue
|
||||||
try:
|
tbl.execute(f"DELETE FROM {table}")
|
||||||
dst.execute(f"DELETE FROM {table}")
|
|
||||||
placeholders = ", ".join(["?"] * len(cols))
|
placeholders = ", ".join(["?"] * len(cols))
|
||||||
col_sql = ", ".join(cols)
|
col_sql = ", ".join(cols)
|
||||||
insert_sql = f"INSERT INTO {table} ({col_sql}) VALUES ({placeholders})"
|
insert_sql = f"INSERT INTO {table} ({col_sql}) VALUES ({placeholders})"
|
||||||
for row in rows:
|
for row in rows:
|
||||||
dst.execute(insert_sql, tuple(row[c] for c in cols))
|
tbl.execute(insert_sql, tuple(row[c] for c in cols))
|
||||||
if "id" in cols:
|
if "id" in cols:
|
||||||
try:
|
_reset_sequences(tbl, table, "id")
|
||||||
_reset_sequences(dst, table, "id")
|
tbl.commit()
|
||||||
except Exception:
|
|
||||||
rollback_if_postgres(dst)
|
|
||||||
dst.commit()
|
|
||||||
stats[table] = len(rows)
|
stats[table] = len(rows)
|
||||||
print(f" {table}: {len(rows)} 行")
|
print(f" {table}: {len(rows)} 行")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
rollback_if_postgres(dst)
|
rollback_if_postgres(tbl)
|
||||||
raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc
|
raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc
|
||||||
|
finally:
|
||||||
|
tbl.close()
|
||||||
finally:
|
finally:
|
||||||
src.close()
|
src.close()
|
||||||
dst.close()
|
|
||||||
|
|
||||||
total = sum(stats.values())
|
total = sum(stats.values())
|
||||||
print(f"==> 完成,共迁移 {total} 行")
|
print(f"==> 完成,共迁移 {total} 行")
|
||||||
|
|||||||
Reference in New Issue
Block a user