6abe06d935
Co-authored-by: Cursor <cursoragent@cursor.com>
131 lines
4.5 KiB
Python
131 lines
4.5 KiB
Python
#!/usr/bin/env python3
|
||
# Copyright (c) 2025-2026 马建军. All rights reserved.
|
||
"""将 SQLite futures.db 迁移到 PostgreSQL(需已 init 表结构,见 deploy_postgres.sh)。"""
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import os
|
||
import sqlite3
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# Directory one level above the folder that contains this script
# (parents[1] of the resolved file path).
ROOT = Path(__file__).resolve().parents[1]
# Put ROOT at the front of sys.path (once) before the project import below.
# NOTE(review): presumably `db_conn` lives directly under ROOT — confirm.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from dotenv import load_dotenv

# Load ROOT/.env before importing db_conn.
# NOTE(review): presumably db_conn reads DATABASE_URL from the environment at
# import time, which is why this call precedes the import — confirm.
load_dotenv(ROOT / ".env")

from db_conn import DB_PATH, connect_db, db_backend, is_postgres, rollback_if_postgres  # noqa: E402
|
||
|
||
|
||
def _sqlite_tables(conn: sqlite3.Connection) -> list[str]:
    """Return the names of the tables in a SQLite database, sorted by name.

    Entries of ``sqlite_master`` whose name matches the LIKE pattern
    ``sqlite_%`` (SQLite's own bookkeeping tables) are left out.

    Args:
        conn: Open SQLite connection to inspect.

    Returns:
        Table names in ascending order.
    """
    query = "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    records = conn.execute(query).fetchall()
    names: list[str] = []
    for record in records:
        # Only one column is selected, so the name is always at index 0.
        names.append(record[0])
    return names
|
||
|
||
|
||
def _pg_columns(pg_conn, table: str) -> list[str]:
    """Return the column names of ``table`` in PostgreSQL's ``public`` schema.

    The names come from ``information_schema.columns`` and are ordered by
    ``ordinal_position``. The list is empty when the catalog has no rows for
    the table, which callers use as a "table does not exist" signal.

    Args:
        pg_conn: Connection whose ``execute(sql, params)`` returns an object
            with ``fetchall()``. Rows are read by key
            (``row["column_name"]``), so they must be mapping-style.
        table: Table name, passed as a bound parameter.

    Returns:
        Column names in table-definition order.
    """
    query = (
        "SELECT column_name FROM information_schema.columns\n"
        "WHERE table_schema='public' AND table_name=%s\n"
        "ORDER BY ordinal_position"
    )
    result = pg_conn.execute(query, (table,))
    column_names: list[str] = []
    for record in result.fetchall():
        column_names.append(record["column_name"])
    return column_names
|
||
|
||
|
||
def _reset_sequences(pg_conn, table: str, pk: str = "id") -> None:
    """Set the serial sequence of ``table.pk`` to the column's current maximum.

    Issues a single ``SELECT setval(...)`` whose target is looked up with
    ``pg_get_serial_sequence`` and whose value is ``MAX(pk)``, falling back
    to 1 when that maximum is NULL. The third ``setval`` argument is
    ``true``.

    ``table`` and ``pk`` are interpolated into the SQL text as-is, so callers
    must only pass trusted identifiers.

    Args:
        pg_conn: Connection exposing ``execute(sql)``.
        table: Name of the table that owns the sequence.
        pk: Name of the sequence-backed column.
    """
    sequence_expr = f"pg_get_serial_sequence('{table}', '{pk}')"
    value_expr = f"COALESCE((SELECT MAX({pk}) FROM {table}), 1)"
    statement = f"SELECT setval({sequence_expr}, {value_expr}, true)"
    pg_conn.execute(statement)
|
||
|
||
|
||
def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
    """Copy the tables of a SQLite database into the configured PostgreSQL database.

    For every table found in the SQLite file, the columns that exist on both
    sides are copied. Existing PostgreSQL rows of a table are deleted before
    its rows are inserted, and every table is committed on its own, so tables
    migrated before a failure stay migrated. Tables that are missing from
    PostgreSQL, or that share no column with it, are skipped. Tables without
    source rows are left untouched in PostgreSQL and recorded as 0.

    After a table's rows are committed, the sequence behind its ``id`` column
    (if the column was copied) is reset on a best-effort basis: a failure
    there is reported on stderr but does not undo the copied rows.

    Args:
        sqlite_path: Path of the source SQLite file; ``DB_PATH`` is used when
            this is falsy.
        dry_run: When true, only count the rows that would be copied; nothing
            is written to PostgreSQL.

    Returns:
        Mapping of table name to the number of rows migrated (or counted, in
        dry-run mode). Skipped tables are absent.

    Raises:
        RuntimeError: If ``is_postgres()`` is false, or if copying a table
            fails (the original exception is chained as the cause).
        FileNotFoundError: If the source SQLite file does not exist.
    """
    if not is_postgres():
        raise RuntimeError("请先配置 DATABASE_URL=postgresql://... 后再运行迁移")

    src_path = sqlite_path or DB_PATH
    if not os.path.isfile(src_path):
        raise FileNotFoundError(f"SQLite 源库不存在: {src_path}")

    print(f"==> 源库: {src_path}")
    # Only the part after the last '@' is printed, which keeps any
    # user:password prefix of the URL out of the output.
    print(f"==> 目标: PostgreSQL ({os.getenv('DATABASE_URL', '').split('@')[-1]})")

    stats: dict[str, int] = {}
    src = sqlite3.connect(src_path)
    dst = None
    try:
        # Everything after opening `src` runs under the try so that the SQLite
        # handle is released even when connecting to PostgreSQL fails.
        src.row_factory = sqlite3.Row
        dst = connect_db()

        tables = _sqlite_tables(src)
        print(f"==> 共 {len(tables)} 张表: {', '.join(tables)}")

        for table in tables:
            pg_cols = _pg_columns(dst, table)
            if not pg_cols:
                print(f" 跳过 {table}(PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)")
                continue

            # NOTE: table and column names are interpolated into SQL unquoted;
            # they originate from the source database's own catalog.
            pg_col_set = set(pg_cols)
            src_cols = [c[1] for c in src.execute(f"PRAGMA table_info({table})").fetchall()]
            cols = [c for c in src_cols if c in pg_col_set]
            if not cols:
                print(f" 跳过 {table}(无共同列)")
                continue

            if dry_run:
                # Counting in SQL avoids loading the whole table into memory.
                count = src.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
                stats[table] = count
                print(f" [dry-run] {table}: {count} 行")
                continue

            rows = src.execute(f"SELECT {', '.join(cols)} FROM {table}").fetchall()
            if not rows:
                stats[table] = 0
                continue

            try:
                dst.execute(f"DELETE FROM {table}")
                # NOTE(review): '?' placeholders are used here while
                # _pg_columns uses '%s' — presumably connect_db() returns a
                # wrapper that accepts both; confirm.
                placeholders = ", ".join(["?"] * len(cols))
                col_sql = ", ".join(cols)
                insert_sql = f"INSERT INTO {table} ({col_sql}) VALUES ({placeholders})"
                for row in rows:
                    dst.execute(insert_sql, tuple(row[c] for c in cols))
                # Commit the data before touching the sequence: a failed
                # statement followed by a rollback in the same transaction
                # would otherwise discard the rows that were just inserted.
                dst.commit()
            except Exception as exc:
                rollback_if_postgres(dst)
                raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc

            if "id" in cols:
                # Best effort, in its own transaction; the copied rows are
                # already durable at this point.
                try:
                    _reset_sequences(dst, table, "id")
                    dst.commit()
                except Exception as exc:
                    rollback_if_postgres(dst)
                    print(f" 警告: {table} 的 id 序列重置失败(数据已提交): {exc}", file=sys.stderr)

            stats[table] = len(rows)
            print(f" {table}: {len(rows)} 行")
    finally:
        # Close both handles even if closing the first one raises.
        try:
            src.close()
        finally:
            if dst is not None:
                dst.close()

    total = sum(stats.values())
    print(f"==> 完成,共迁移 {total} 行")
    return stats
|
||
|
||
|
||
def main() -> int:
    """Command-line entry point for the SQLite -> PostgreSQL migration.

    Parses ``--sqlite`` (source path, default ``DB_PATH``) and ``--dry-run``
    from the process arguments, then runs ``migrate``.

    Returns:
        0 on success; 1 when ``db_backend()`` is not ``"postgres"`` or when
        ``migrate`` raises (the error is printed to stderr).
    """
    cli = argparse.ArgumentParser(description="SQLite -> PostgreSQL 数据迁移")
    cli.add_argument("--sqlite", default=DB_PATH, help=f"SQLite 路径,默认 {DB_PATH}")
    cli.add_argument("--dry-run", action="store_true", help="仅统计行数,不写入")
    options = cli.parse_args()

    exit_code = 0
    if db_backend() == "postgres":
        try:
            migrate(sqlite_path=options.sqlite, dry_run=options.dry_run)
        except Exception as exc:
            # Top-level boundary: report any failure and signal it via the exit code.
            print(f"迁移失败: {exc}", file=sys.stderr)
            exit_code = 1
    else:
        print("错误: 未检测到 DATABASE_URL(postgresql://...)", file=sys.stderr)
        exit_code = 1
    return exit_code
|
||
|
||
|
||
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|