Inline account_risk_state DDL in init_db; use autocommit for migration.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-07-01 08:26:13 +08:00
parent 06fbff04a7
commit d7ea7b9e8a
2 changed files with 43 additions and 4 deletions
+25 -4
View File
@@ -13,11 +13,34 @@ ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
try:
import psycopg
from psycopg.rows import dict_row
except ImportError:
psycopg = None # type: ignore[assignment]
dict_row = None # type: ignore[assignment]
from dotenv import load_dotenv
load_dotenv(ROOT / ".env")
from db_conn import DB_PATH, connect_db, db_backend, is_postgres, rollback_if_postgres # noqa: E402
from db_conn import DB_PATH, db_backend, is_postgres # noqa: E402
def _migrate_conn():
"""迁移专用连接(autocommit,避免单表失败污染事务)。"""
if not is_postgres():
raise RuntimeError("需要 PostgreSQL")
url = (os.getenv("DATABASE_URL") or "").strip()
raw = psycopg.connect(url, row_factory=dict_row, autocommit=True)
try:
with raw.cursor() as cur:
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
except Exception:
pass
from db_conn import DbConnection
return DbConnection("postgres", raw, from_pool=False)
def _sqlite_tables(conn: sqlite3.Connection) -> list[str]:
@@ -70,7 +93,7 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
try:
for table in tables:
tbl = connect_db()
tbl = _migrate_conn()
try:
pg_cols = _pg_columns(tbl, table)
if not pg_cols:
@@ -97,11 +120,9 @@ def migrate(*, sqlite_path: str | None = None, dry_run: bool = False) -> dict:
tbl.execute(insert_sql, tuple(row[c] for c in cols))
if "id" in cols:
_reset_sequences(tbl, table, "id")
tbl.commit()
stats[table] = len(rows)
print(f" {table}: {len(rows)}")
except Exception as exc:
rollback_if_postgres(tbl)
raise RuntimeError(f"迁移表 {table} 失败: {exc}") from exc
finally:
tbl.close()