Files
qihuo/scripts/migrate_sqlite_to_postgres.py
T
2026-07-01 08:22:54 +08:00

131 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# Copyright (c) 2025-2026 马建军. All rights reserved.
"""将 SQLite futures.db 迁移到 PostgreSQL(需已 init 表结构,见 deploy_postgres.sh)。"""
from __future__ import annotations
import argparse
import os
import sqlite3
import sys
from pathlib import Path
# Put the project root (the parent of scripts/) first on sys.path so the
# project-local ``db_conn`` module imported below resolves when this file is
# run directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from dotenv import load_dotenv
# Load <ROOT>/.env before importing db_conn. NOTE(review): presumably db_conn
# reads its settings (e.g. DATABASE_URL) from the environment at import time,
# which is why this must stay ahead of the import below — confirm.
load_dotenv(ROOT / ".env")
from db_conn import DB_PATH, connect_db, db_backend, is_postgres, rollback_if_postgres # noqa: E402
def _sqlite_tables(conn: sqlite3.Connection) -> list[str]:
rows = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
).fetchall()
return [r[0] for r in rows]
def _pg_columns(pg_conn, table: str) -> list[str]:
rows = pg_conn.execute(
"""SELECT column_name FROM information_schema.columns
WHERE table_schema='public' AND table_name=%s
ORDER BY ordinal_position""",
(table,),
).fetchall()
return [r["column_name"] for r in rows]
def _reset_sequences(pg_conn, table: str, pk: str = "id") -> None:
pg_conn.execute(
f"SELECT setval(pg_get_serial_sequence('{table}', '{pk}'), "
f"COALESCE((SELECT MAX({pk}) FROM {table}), 1), true)"
)
def _quote_ident(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier (SQLite and PostgreSQL).

    Embedded double quotes are doubled. Quoting preserves the exact spelling,
    so mixed-case names and reserved words (e.g. ``user``, ``order``) work.
    """
    return '"' + name.replace('"', '""') + '"'


def _copy_table(src: sqlite3.Connection, dst, table: str, *, dry_run: bool) -> int | None:
    """Copy one table from *src* (SQLite) into *dst* (PostgreSQL).

    Only columns present on both sides are copied, in SQLite column order.
    A non-empty table is replaced in one transaction (DELETE, INSERTs,
    commit). An empty source table leaves the target untouched.

    Returns:
        The number of source rows copied (or counted when *dry_run*), or
        ``None`` when the table is skipped because PostgreSQL lacks it or
        the two sides share no columns.

    Raises:
        RuntimeError: If deleting/inserting/committing fails; that table's
            transaction is rolled back first.
    """
    pg_cols = set(_pg_columns(dst, table))
    if not pg_cols:
        print(f" 跳过 {table}(PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)")
        return None
    qtable = _quote_ident(table)
    src_cols = [c[1] for c in src.execute(f"PRAGMA table_info({qtable})").fetchall()]
    cols = [c for c in src_cols if c in pg_cols]
    if not cols:
        print(f" 跳过 {table}(无共同列)")
        return None
    col_sql = ", ".join(_quote_ident(c) for c in cols)
    rows = src.execute(f"SELECT {col_sql} FROM {qtable}").fetchall()
    if dry_run:
        print(f" [dry-run] {table}: {len(rows)}")
        return len(rows)
    if not rows:
        return 0
    placeholders = ", ".join(["?"] * len(cols))
    insert_sql = f"INSERT INTO {qtable} ({col_sql}) VALUES ({placeholders})"
    try:
        dst.execute(f"DELETE FROM {qtable}")
        for row in rows:
            # The SELECT yields exactly ``cols`` in order, so positional values
            # line up with the INSERT column list.
            dst.execute(insert_sql, tuple(row))
        dst.commit()
    except Exception as exc:
        rollback_if_postgres(dst)
        raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc
    if "id" in cols:
        # Runs after the data commit, in its own transaction, so a failure here
        # rolls back only the setval — never the rows that were just copied.
        try:
            _reset_sequences(dst, table, "id")
            dst.commit()
        except Exception as exc:
            rollback_if_postgres(dst)
            print(f" 警告: 重置 {table}.id 序列失败: {exc}", file=sys.stderr)
    print(f" {table}: {len(rows)}")
    return len(rows)


def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict[str, int]:
    """Copy every SQLite table into the configured PostgreSQL database.

    The PostgreSQL schema must already exist (see deploy_postgres.sh).
    Tables are processed in name order, each committed independently.

    Args:
        sqlite_path: SQLite source file; ``None`` means ``DB_PATH``.
        dry_run: Only count source rows; write nothing to PostgreSQL.

    Returns:
        Mapping of table name to rows migrated (or counted in dry-run mode).
        Skipped tables are absent.

    Raises:
        RuntimeError: If PostgreSQL is not configured, or a table fails to
            copy. Tables committed before the failure stay migrated.
        FileNotFoundError: If the SQLite source file does not exist.
    """
    if not is_postgres():
        raise RuntimeError("请先配置 DATABASE_URL=postgresql://... 后再运行迁移")
    src_path = sqlite_path or DB_PATH
    if not os.path.isfile(src_path):
        raise FileNotFoundError(f"SQLite 源库不存在: {src_path}")
    print(f"==> 源库: {src_path}")
    # Show only the part after the last '@' so credentials are not echoed.
    print(f"==> 目标: PostgreSQL ({os.getenv('DATABASE_URL', '').split('@')[-1]})")
    stats: dict[str, int] = {}
    src = sqlite3.connect(src_path)
    try:
        src.row_factory = sqlite3.Row
        # Connected inside the outer try so the SQLite handle is closed even
        # when connecting to PostgreSQL fails.
        dst = connect_db()
        try:
            tables = _sqlite_tables(src)
            print(f"==> 共 {len(tables)} 张表: {', '.join(tables)}")
            for table in tables:
                copied = _copy_table(src, dst, table, dry_run=dry_run)
                if copied is not None:
                    stats[table] = copied
        finally:
            dst.close()
    finally:
        src.close()
    total = sum(stats.values())
    print(f"==> 完成,共迁移 {total}")
    return stats
def main(argv: list[str] | None = None) -> int:
    """Parse the command line and run :func:`migrate`.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        Exit status: 0 on success, 1 when PostgreSQL is not configured or the
        migration raises.
    """
    parser = argparse.ArgumentParser(description="SQLite -> PostgreSQL 数据迁移")
    # argparse %-formats help strings. %(default)s shows DB_PATH safely, whereas
    # splicing the path in directly would make --help crash on a literal '%'.
    parser.add_argument("--sqlite", default=DB_PATH, help="SQLite 路径,默认 %(default)s")
    parser.add_argument("--dry-run", action="store_true", help="仅统计行数,不写入")
    args = parser.parse_args(argv)
    if db_backend() != "postgres":
        print("错误: 未检测到 DATABASE_URLpostgresql://...", file=sys.stderr)
        return 1
    try:
        migrate(sqlite_path=args.sqlite, dry_run=args.dry_run)
    except Exception as exc:  # CLI boundary: report and exit non-zero
        print(f"迁移失败: {exc}", file=sys.stderr)
        return 1
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())