#!/usr/bin/env python3 # Copyright (c) 2025-2026 马建军. All rights reserved. """将 SQLite futures.db 迁移到 PostgreSQL(需已 init 表结构,见 deploy_postgres.sh)。""" from __future__ import annotations import argparse import os import sqlite3 import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) try: import psycopg from psycopg.rows import dict_row except ImportError: psycopg = None # type: ignore[assignment] dict_row = None # type: ignore[assignment] from dotenv import load_dotenv load_dotenv(ROOT / ".env") from db_conn import DB_PATH, db_backend, is_postgres # noqa: E402 def _migrate_conn(): """迁移专用连接(autocommit,避免单表失败污染事务)。""" if not is_postgres(): raise RuntimeError("需要 PostgreSQL") url = (os.getenv("DATABASE_URL") or "").strip() raw = psycopg.connect(url, row_factory=dict_row, autocommit=True) try: with raw.cursor() as cur: cur.execute("SET TIME ZONE 'Asia/Shanghai'") except Exception: pass from db_conn import DbConnection return DbConnection("postgres", raw, from_pool=False) def _sqlite_tables(conn: sqlite3.Connection) -> list[str]: rows = conn.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name" ).fetchall() return [r[0] for r in rows] def _pg_columns(pg_conn, table: str) -> list[str]: rows = pg_conn.execute( """SELECT column_name FROM information_schema.columns WHERE table_schema='public' AND table_name=%s ORDER BY ordinal_position""", (table,), ).fetchall() return [r["column_name"] for r in rows] def _reset_sequences(pg_conn, table: str, pk: str = "id") -> None: row = pg_conn.execute( "SELECT pg_get_serial_sequence(%s, %s) AS seq", (table, pk) ).fetchone() seq = (row["seq"] if row else None) or None if not seq: return pg_conn.execute( f"SELECT setval(%s, COALESCE((SELECT MAX({pk}) FROM {table}), 1), true)", (seq,), ) def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict: if not is_postgres(): raise RuntimeError("请先配置 DATABASE_URL=postgresql://... 后再运行迁移") src_path = sqlite_path or DB_PATH if not os.path.isfile(src_path): raise FileNotFoundError(f"SQLite 源库不存在: {src_path}") print(f"==> 源库: {src_path}") print(f"==> 目标: PostgreSQL ({os.getenv('DATABASE_URL', '').split('@')[-1]})") src = sqlite3.connect(src_path) src.row_factory = sqlite3.Row stats: dict[str, int] = {} tables = _sqlite_tables(src) print(f"==> 共 {len(tables)} 张表: {', '.join(tables)}") try: for table in tables: tbl = _migrate_conn() try: pg_cols = _pg_columns(tbl, table) if not pg_cols: print(f" 跳过 {table}(PostgreSQL 无此表,请先运行 deploy_postgres.sh 初始化)") continue src_cols = [c[1] for c in src.execute(f"PRAGMA table_info({table})").fetchall()] cols = [c for c in src_cols if c in pg_cols] if not cols: print(f" 跳过 {table}(无共同列)") continue rows = src.execute(f"SELECT {', '.join(cols)} FROM {table}").fetchall() if dry_run: stats[table] = len(rows) print(f" [dry-run] {table}: {len(rows)} 行") continue if not rows: stats[table] = 0 continue tbl.execute(f"DELETE FROM {table}") placeholders = ", ".join(["?"] * len(cols)) col_sql = ", ".join(cols) insert_sql = f"INSERT INTO {table} ({col_sql}) VALUES ({placeholders})" for row in rows: tbl.execute(insert_sql, tuple(row[c] for c in cols)) if "id" in cols: _reset_sequences(tbl, table, "id") stats[table] = len(rows) print(f" {table}: {len(rows)} 行") except Exception as exc: raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc finally: tbl.close() finally: src.close() total = sum(stats.values()) print(f"==> 完成,共迁移 {total} 行") return stats def main() -> int: parser = argparse.ArgumentParser(description="SQLite -> PostgreSQL 数据迁移") parser.add_argument("--sqlite", default=DB_PATH, help=f"SQLite 路径,默认 {DB_PATH}") parser.add_argument("--dry-run", action="store_true", help="仅统计行数,不写入") args = parser.parse_args() if db_backend() != "postgres": print("错误: 未检测到 DATABASE_URL(postgresql://...)", file=sys.stderr) return 1 try: migrate(sqlite_path=args.sqlite, dry_run=args.dry_run) except Exception as exc: print(f"迁移失败: {exc}", file=sys.stderr) return 1 return 0 if __name__ == "__main__": raise SystemExit(main())