Fix SQLite migration: no app import, commit per table.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright (c) 2025-2026 马建军. All rights reserved.
|
# Copyright (c) 2025-2026 马建军. All rights reserved.
|
||||||
"""将 SQLite futures.db 迁移到 PostgreSQL(需已配置 DATABASE_URL 并 init 空库)。"""
|
"""将 SQLite futures.db 迁移到 PostgreSQL(需已 init 表结构,见 deploy_postgres.sh)。"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -17,7 +17,7 @@ from dotenv import load_dotenv
|
|||||||
|
|
||||||
load_dotenv(ROOT / ".env")
|
load_dotenv(ROOT / ".env")
|
||||||
|
|
||||||
from db_conn import DB_PATH, connect_db, db_backend, is_postgres # noqa: E402
|
from db_conn import DB_PATH, connect_db, db_backend, is_postgres, rollback_if_postgres # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def _sqlite_tables(conn: sqlite3.Connection) -> list[str]:
|
def _sqlite_tables(conn: sqlite3.Connection) -> list[str]:
|
||||||
@@ -40,14 +40,12 @@ def _pg_columns(pg_conn, table: str) -> list[str]:
|
|||||||
def _reset_sequences(pg_conn, table: str, pk: str = "id") -> None:
|
def _reset_sequences(pg_conn, table: str, pk: str = "id") -> None:
|
||||||
try:
|
try:
|
||||||
pg_conn.execute(
|
pg_conn.execute(
|
||||||
f"""SELECT setval(
|
f"SELECT setval(pg_get_serial_sequence('{table}', '{pk}'), "
|
||||||
pg_get_serial_sequence('{table}', '{pk}'),
|
f"COALESCE((SELECT MAX({pk}) FROM {table}), 1), true)"
|
||||||
COALESCE((SELECT MAX({pk}) FROM {table}), 1),
|
|
||||||
true
|
|
||||||
)"""
|
|
||||||
)
|
)
|
||||||
|
pg_conn.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
rollback_if_postgres(pg_conn)
|
||||||
|
|
||||||
|
|
||||||
def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
|
def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
|
||||||
@@ -61,12 +59,6 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
|
|||||||
print(f"==> 源库: {src_path}")
|
print(f"==> 源库: {src_path}")
|
||||||
print(f"==> 目标: PostgreSQL ({os.getenv('DATABASE_URL', '').split('@')[-1]})")
|
print(f"==> 目标: PostgreSQL ({os.getenv('DATABASE_URL', '').split('@')[-1]})")
|
||||||
|
|
||||||
if not dry_run:
|
|
||||||
print("==> 初始化 PostgreSQL 表结构...")
|
|
||||||
from app import init_db
|
|
||||||
|
|
||||||
init_db()
|
|
||||||
|
|
||||||
src = sqlite3.connect(src_path)
|
src = sqlite3.connect(src_path)
|
||||||
src.row_factory = sqlite3.Row
|
src.row_factory = sqlite3.Row
|
||||||
dst = connect_db()
|
dst = connect_db()
|
||||||
@@ -79,7 +71,7 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
|
|||||||
for table in tables:
|
for table in tables:
|
||||||
pg_cols = _pg_columns(dst, table)
|
pg_cols = _pg_columns(dst, table)
|
||||||
if not pg_cols:
|
if not pg_cols:
|
||||||
print(f" 跳过 {table}(PostgreSQL 无此表,请先 init_db)")
|
print(f" 跳过 {table}(PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)")
|
||||||
continue
|
continue
|
||||||
src_cols = [c[1] for c in src.execute(f"PRAGMA table_info({table})").fetchall()]
|
src_cols = [c[1] for c in src.execute(f"PRAGMA table_info({table})").fetchall()]
|
||||||
cols = [c for c in src_cols if c in pg_cols]
|
cols = [c for c in src_cols if c in pg_cols]
|
||||||
@@ -94,18 +86,22 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
|
|||||||
if not rows:
|
if not rows:
|
||||||
stats[table] = 0
|
stats[table] = 0
|
||||||
continue
|
continue
|
||||||
dst.execute(f"DELETE FROM {table}")
|
try:
|
||||||
placeholders = ", ".join(["?"] * len(cols))
|
dst.execute(f"DELETE FROM {table}")
|
||||||
col_sql = ", ".join(cols)
|
placeholders = ", ".join(["?"] * len(cols))
|
||||||
insert_sql = f"INSERT INTO {table} ({col_sql}) VALUES ({placeholders})"
|
col_sql = ", ".join(cols)
|
||||||
for row in rows:
|
insert_sql = f"INSERT INTO {table} ({col_sql}) VALUES ({placeholders})"
|
||||||
dst.execute(insert_sql, tuple(row[c] for c in cols))
|
for row in rows:
|
||||||
stats[table] = len(rows)
|
dst.execute(insert_sql, tuple(row[c] for c in cols))
|
||||||
if "id" in cols:
|
if "id" in cols:
|
||||||
_reset_sequences(dst, table, "id")
|
_reset_sequences(dst, table, "id")
|
||||||
print(f" {table}: {len(rows)} 行")
|
else:
|
||||||
if not dry_run:
|
dst.commit()
|
||||||
dst.commit()
|
stats[table] = len(rows)
|
||||||
|
print(f" {table}: {len(rows)} 行")
|
||||||
|
except Exception as exc:
|
||||||
|
rollback_if_postgres(dst)
|
||||||
|
raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc
|
||||||
finally:
|
finally:
|
||||||
src.close()
|
src.close()
|
||||||
dst.close()
|
dst.close()
|
||||||
|
|||||||
Reference in New Issue
Block a user