# Commit 52aca456e9: Support DATABASE_URL with connection pooling, pg_dump backups,
# SQLite migration script, and deploy_postgres.sh with docs.
# Co-authored-by: Cursor <cursoragent@cursor.com>
# (347 lines, 11 KiB, Python)
# Copyright (c) 2025-2026 马建军. All rights reserved.
|
||
# 专有软件 — 未经授权禁止复制、传播、转售。
|
||
# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。
|
||
# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md
|
||
|
||
"""数据库连接:开发默认 SQLite,生产推荐 PostgreSQL(DATABASE_URL)。"""
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
import re
|
||
import sqlite3
|
||
import threading
|
||
import time
|
||
from typing import Any, Iterable, Optional, Sequence
|
||
|
||
# Default SQLite database file: ``futures.db`` located next to this module.
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "futures.db")

# Cached backend name ("sqlite" / "postgres"), detected lazily by db_backend().
_backend_lock = threading.Lock()
_backend: Optional[str] = None

# psycopg ConnectionPool, created lazily on first PostgreSQL use; None until then.
_pg_pool = None
_pg_pool_lock = threading.Lock()
|
||
|
||
# PostgreSQL support is optional: psycopg / psycopg_pool are only required when
# DATABASE_URL points at PostgreSQL (db_backend() enforces that).
try:
    import psycopg
    from psycopg import OperationalError as PgOperationalError
    from psycopg import IntegrityError as PgIntegrityError
    from psycopg.rows import dict_row
    from psycopg_pool import ConnectionPool

    _PSYCOPG_OK = True
except ImportError:
    psycopg = None  # type: ignore[assignment]

    # Keep the names importable, but as dedicated classes that nothing ever
    # raises. Aliasing them to ``Exception`` would make ``except PgIntegrityError``
    # (or an ``except (IntegrityError, PgIntegrityError)`` tuple) in SQLite mode
    # catch -- and misclassify -- every error.
    class PgOperationalError(Exception):  # type: ignore[no-redef]
        """Placeholder used when psycopg is not installed; never raised."""

    class PgIntegrityError(Exception):  # type: ignore[no-redef]
        """Placeholder used when psycopg is not installed; never raised."""

    dict_row = None  # type: ignore[assignment]
    ConnectionPool = None  # type: ignore[misc,assignment]
    _PSYCOPG_OK = False

# Public aliases for callers. These are always the sqlite3 classes, whatever the
# backend. NOTE(review): under PostgreSQL they do not match psycopg errors.
OperationalError = sqlite3.OperationalError
IntegrityError = sqlite3.IntegrityError
|
||
|
||
|
||
def db_backend() -> str:
    """Return the active backend name, ``sqlite`` or ``postgres``.

    The backend is detected once from ``DATABASE_URL`` and cached in ``_backend``.
    A ``postgresql://`` or ``postgres://`` URL selects PostgreSQL; anything else,
    including an unset variable, selects SQLite. The first detection is guarded by
    ``_backend_lock``, and the cache is re-checked under the lock.

    Raises:
        RuntimeError: a PostgreSQL URL is configured but psycopg is not importable.
    """
    global _backend
    cached = _backend
    if cached is not None:
        return cached
    with _backend_lock:
        if _backend is None:
            url = (os.getenv("DATABASE_URL") or "").strip()
            wants_postgres = url.startswith(("postgresql://", "postgres://"))
            if wants_postgres and not _PSYCOPG_OK:
                raise RuntimeError(
                    "已配置 DATABASE_URL 但未安装 psycopg,请执行: pip install 'psycopg[binary]' psycopg-pool"
                )
            _backend = "postgres" if wants_postgres else "sqlite"
        return _backend
|
||
|
||
|
||
def is_postgres() -> bool:
    """Return True when the detected backend is PostgreSQL."""
    backend = db_backend()
    return backend == "postgres"
|
||
|
||
|
||
def database_label() -> str:
    """Return a short, log-safe description of the database in use.

    SQLite: ``SQLite (<DB_PATH>)``.
    PostgreSQL: ``PostgreSQL (<host[:port]>)``, built only from the network
    location of ``DATABASE_URL``. User, password, database name and query
    parameters (which may carry ``password=...``) are never included. Falls back
    to ``postgresql`` when the URL has no host part, e.g. a Unix-socket URL.
    """
    if not is_postgres():
        return f"SQLite ({DB_PATH})"

    from urllib.parse import urlsplit

    url = (os.getenv("DATABASE_URL") or "").strip()
    try:
        # netloc excludes path, query string and fragment.
        netloc = urlsplit(url).netloc
    except ValueError:  # e.g. a malformed bracketed IPv6 host
        netloc = ""
    # Drop any "user:password@" prefix; passwords may contain "@", so split on the last one.
    host = netloc.rpartition("@")[2]
    return f"PostgreSQL ({host or 'postgresql'})"
|
||
|
||
|
||
def _to_pyformat(sql: str) -> str:
    """Convert qmark-style SQL to psycopg's format style, respecting quoted text.

    ``?`` outside single/double quotes becomes ``%s``. Every ``%`` (inside quotes
    too) is doubled to ``%%``, because psycopg scans the whole query text for
    placeholders whenever parameters are passed. Doubled quotes (``''``) inside a
    literal close and immediately reopen it, which keeps the scan in sync.
    """
    pieces: list[str] = []
    quote: str | None = None
    for ch in sql:
        if ch == "%":
            pieces.append("%%")
        elif quote is not None:
            pieces.append(ch)
            if ch == quote:
                quote = None
        elif ch in "'\"":
            quote = ch
            pieces.append(ch)
        elif ch == "?":
            pieces.append("%s")
        else:
            pieces.append(ch)
    return "".join(pieces)


def adapt_sql(sql: str) -> str:
    """Translate SQLite-flavoured SQL for the active backend.

    SQLite: returned unchanged.

    PostgreSQL:
      * ``INTEGER PRIMARY KEY AUTOINCREMENT`` becomes ``SERIAL PRIMARY KEY``, and
        any other ``AUTOINCREMENT`` keyword is dropped;
      * ``?`` placeholders outside quoted literals/identifiers become ``%s``;
      * literal ``%`` is escaped as ``%%``. DbCursor.execute always passes a
        parameter sequence, so an unescaped ``LIKE 'a%'`` would be rejected by
        psycopg as a bad placeholder.

    NOTE(review): the PostgreSQL output is meant to be executed *with* a parameter
    sequence, as DbCursor.execute does; executing it with no params at all would
    send ``%%`` verbatim.
    """
    if not is_postgres():
        return sql
    out = re.sub(
        r"\bINTEGER\s+PRIMARY\s+KEY\s+AUTOINCREMENT\b",
        "SERIAL PRIMARY KEY",
        sql,
        flags=re.IGNORECASE,
    )
    out = re.sub(r"\bAUTOINCREMENT\b", "", out, flags=re.IGNORECASE)
    return _to_pyformat(out)
|
||
|
||
|
||
def is_benign_migration_error(exc: BaseException) -> bool:
    """Return True if *exc* is an ignorable "already applied" init-migration error.

    Covers ``ALTER TABLE ... ADD COLUMN`` for an existing column, ``CREATE`` of an
    existing table/index, and duplicate-key inserts. Detection uses two signals:

    * the message text (``duplicate column`` / ``already exists`` /
      ``duplicate key``), which covers SQLite and English PostgreSQL messages;
    * the PostgreSQL SQLSTATE exposed by psycopg as ``exc.sqlstate``: 42701
      (duplicate_column) and 42P07 (duplicate_table). This still works when the
      server localizes its messages (``lc_messages``).

    The SQLSTATE is read from any exception type. psycopg raises these errors as
    ``ProgrammingError`` subclasses, not ``OperationalError``, so restricting the
    check to OperationalError would never match.
    """
    msg = str(exc).lower()
    if any(marker in msg for marker in ("duplicate column", "already exists", "duplicate key")):
        return True
    code = getattr(exc, "sqlstate", None) or ""
    return code in ("42701", "42P07")
|
||
|
||
|
||
class DbCursor:
    """Backend-neutral cursor exposing the sqlite3.Cursor subset this project uses.

    Wraps a raw ``sqlite3`` or psycopg cursor and offers ``execute``,
    ``fetchone``, ``fetchall``, iteration, ``description``, ``close`` and the
    ``lastrowid`` / ``rowcount`` attributes.
    """

    def __init__(self, backend: str, raw_cursor: Any, raw_conn: Any) -> None:
        self._backend = backend  # "sqlite" or "postgres"
        self._cur = raw_cursor
        self._conn = raw_conn
        # Mirrors sqlite3.Cursor.lastrowid; execute() fills it in for INSERTs.
        self.lastrowid: Optional[int] = None
        self.rowcount: int = 0

    @property
    def description(self) -> Any:
        """DB-API column metadata of the current result set, or None."""
        return getattr(self._cur, "description", None)

    def __iter__(self) -> Iterable[Any]:
        # sqlite3 cursors are iterable, so callers may write ``for row in conn.execute(...)``.
        return iter(self._cur)

    def execute(self, sql: str, params: Sequence[Any] | None = None) -> "DbCursor":
        """Run SQLite-style *sql* with *params* (adapted via adapt_sql) and return self.

        Updates ``rowcount`` and ``lastrowid``. On PostgreSQL, psycopg cursors
        provide no ``lastrowid``, so for INSERT statements the id is taken from a
        ``RETURNING`` row if there is one, otherwise from ``lastval()``.
        """
        sql = adapt_sql(sql)
        self._cur.execute(sql, params or ())
        self.rowcount = int(getattr(self._cur, "rowcount", 0) or 0)
        self.lastrowid = getattr(self._cur, "lastrowid", None)
        if (
            self.lastrowid is None
            and self._backend == "postgres"
            and re.match(r"^\s*INSERT\b", sql, re.IGNORECASE)
        ):
            self.lastrowid = self._pg_inserted_id()
        return self

    def _pg_inserted_id(self) -> Optional[int]:
        """Best-effort id of the row just inserted on PostgreSQL (None if unknown)."""
        if self._cur.description is None:
            # Plain INSERT without RETURNING: there is no result set to inspect.
            return self._pg_lastval()
        # INSERT ... RETURNING: peek at the first row, then rewind so the
        # caller's own fetchone()/fetchall() still sees it.
        row = self._cur.fetchone()
        if row is None:
            return None
        try:
            self._cur.scroll(0, mode="absolute")
        except Exception:
            # Best effort: if rewinding is unsupported, the peeked row stays consumed.
            pass
        try:
            if isinstance(row, dict):
                return int(row.get("id") or row.get("Id") or 0) or None
            return int(row[0])
        except (TypeError, ValueError, IndexError):
            return None

    def _pg_lastval(self) -> Optional[int]:
        """Return the session's ``lastval()``, or None when it is not available.

        Runs on a separate cursor inside ``conn.transaction()``. When a transaction
        is already open, that is a savepoint, so a failing ``lastval()`` (no
        sequence used yet in this session) is rolled back to the savepoint instead
        of aborting the caller's transaction.

        NOTE(review): lastval() reports the most recent sequence value of the whole
        session, which is not necessarily this table's. For tables without a SERIAL
        id the result may be an unrelated id.
        """
        try:
            with self._conn.transaction():
                with self._conn.cursor() as aux:
                    aux.execute("SELECT lastval()")
                    lv = aux.fetchone()
        except Exception:
            # Expected when lastval() is undefined; the id simply stays unknown.
            return None
        if not lv:
            return None
        return int(lv["lastval"] if isinstance(lv, dict) else lv[0])

    def fetchone(self) -> Any:
        """Return the next row, or None when the result set is exhausted."""
        return self._cur.fetchone()

    def fetchall(self) -> list[Any]:
        """Return all remaining rows."""
        return self._cur.fetchall()

    def close(self) -> None:
        """Close the raw cursor; errors are ignored (best-effort cleanup)."""
        try:
            self._cur.close()
        except Exception:
            pass
|
||
|
||
|
||
class DbConnection:
    """Backend-neutral connection mirroring the sqlite3.Connection subset in use.

    ``execute`` / ``cursor`` return DbCursor objects; ``commit`` / ``rollback`` /
    ``close`` delegate to the raw connection. As a context manager it commits on
    success, rolls back on error, and always closes.
    """

    def __init__(
        self,
        backend: str,
        raw_conn: Any,
        *,
        from_pool: bool = False,
        pool: Any = None,
    ) -> None:
        self._backend = backend  # "sqlite" or "postgres"
        self._conn = raw_conn
        self._from_pool = from_pool
        # Pool that handed out raw_conn. When None, close() falls back to the
        # module-level ``_pg_pool``.
        self._pool = pool
        self._closed = False
        # Kept for sqlite3.Connection API compatibility; not consulted by this class.
        self.row_factory = None

    def execute(self, sql: str, params: Sequence[Any] | None = None) -> DbCursor:
        """Create a cursor, execute *sql* on it and return it (like sqlite3)."""
        cur = self.cursor()
        return cur.execute(sql, params)

    def cursor(self) -> DbCursor:
        """Return a new DbCursor; PostgreSQL cursors yield rows as dicts."""
        if self._backend == "sqlite":
            return DbCursor(self._backend, self._conn.cursor(), self._conn)
        raw = self._conn.cursor(row_factory=dict_row)
        return DbCursor(self._backend, raw, self._conn)

    def commit(self) -> None:
        self._conn.commit()

    def rollback(self) -> None:
        self._conn.rollback()

    def close(self) -> None:
        """Release the connection. Idempotent and best-effort (errors are ignored).

        A pooled PostgreSQL connection is rolled back and returned with
        ``putconn()`` instead of being closed. psycopg_pool expects every
        ``getconn()`` to be paired with ``putconn()``; otherwise the slot is never
        released and the pool eventually runs out of connections.
        """
        if self._closed:
            return
        self._closed = True
        if self._backend == "postgres" and self._from_pool:
            try:
                self._conn.rollback()
            except Exception:
                pass
            pool = self._pool if self._pool is not None else _pg_pool
            if pool is not None:
                try:
                    pool.putconn(self._conn)
                    return
                except Exception:
                    # Not this pool's connection, or the pool is already closed:
                    # fall through and close the connection outright.
                    pass
        try:
            self._conn.close()
        except Exception:
            pass

    def __enter__(self) -> "DbConnection":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Commit on success, roll back on error, and always close.

        A failing commit is rolled back and re-raised so lost writes are not
        hidden. An exception raised inside the ``with`` body is never masked.
        """
        try:
            if exc_type is None:
                try:
                    self.commit()
                except BaseException:
                    try:
                        self.rollback()
                    except Exception:
                        pass
                    raise
            else:
                try:
                    self.rollback()
                except Exception:
                    pass
        finally:
            self.close()
|
||
|
||
|
||
def _pg_pool_instance() -> ConnectionPool:
    """Return the process-wide psycopg connection pool, creating it on first use.

    The cached pool is checked first, then re-checked under ``_pg_pool_lock``, so
    concurrent first callers build only one pool. Pool size comes from
    ``PG_POOL_MIN`` (default 2, at least 1) and ``PG_POOL_MAX`` (default 20, never
    below the minimum). Connections produce dict rows.
    """
    global _pg_pool
    existing = _pg_pool
    if existing is not None:
        return existing
    with _pg_pool_lock:
        if _pg_pool is None:
            conninfo = (os.getenv("DATABASE_URL") or "").strip()
            lower = max(1, int(os.getenv("PG_POOL_MIN", "2") or 2))
            upper = max(lower, int(os.getenv("PG_POOL_MAX", "20") or 20))
            _pg_pool = ConnectionPool(
                conninfo=conninfo,
                min_size=lower,
                max_size=upper,
                kwargs={"row_factory": dict_row},
                open=True,
            )
        return _pg_pool
|
||
|
||
|
||
def connect_db(path: str | None = None) -> DbConnection:
    """Open a connection for the active backend.

    PostgreSQL: borrow a connection from the shared pool and set the session
    time zone to Asia/Shanghai.
    SQLite: open a new connection to *path* (default ``DB_PATH``) with
    ``sqlite3.Row`` rows, a 30 s busy timeout and WAL journaling when it can be
    enabled.
    """
    if is_postgres():
        raw = _pg_pool_instance().getconn()
        try:
            with raw.cursor() as cur:
                cur.execute("SET TIME ZONE 'Asia/Shanghai'")
            raw.commit()
        except Exception:
            # Best effort. A failed statement leaves the implicit transaction
            # aborted, so roll back to hand the caller a usable connection.
            try:
                raw.rollback()
            except Exception:
                pass
        return DbConnection("postgres", raw, from_pool=True)

    db_path = path or DB_PATH
    raw = sqlite3.connect(db_path, timeout=30, check_same_thread=False)
    try:
        raw.row_factory = sqlite3.Row
        raw.execute("PRAGMA busy_timeout=30000")
        try:
            raw.execute("PRAGMA journal_mode=WAL")
        except sqlite3.OperationalError:
            pass  # keep the current journal mode if WAL cannot be enabled
    except BaseException:
        # Do not leak the file handle if connection setup fails.
        raw.close()
        raise
    return DbConnection("sqlite", raw)
|
||
|
||
|
||
def close_pg_pool() -> None:
    """Close the shared PostgreSQL pool, if one exists, and forget it.

    Runs under ``_pg_pool_lock``; a later connect_db() builds a fresh pool.
    """
    global _pg_pool
    with _pg_pool_lock:
        if _pg_pool is None:
            return
        _pg_pool.close()
        _pg_pool = None
|
||
|
||
|
||
def _is_sqlite_lock_error(exc: BaseException) -> bool:
    """Return True for SQLite lock contention ("database is locked" / "table is locked")."""
    return isinstance(exc, sqlite3.OperationalError) and "locked" in str(exc).lower()


def execute_retry(
    conn: DbConnection,
    sql: str,
    params: tuple = (),
    *,
    retries: int = 6,
    base_delay: float = 0.05,
) -> DbCursor:
    """Execute *sql* on *conn*, backing off and retrying while SQLite reports a lock.

    At least one attempt is made, and at most ``retries``. After the n-th failed
    attempt the call sleeps ``base_delay * n`` seconds. Non-lock errors, and the
    lock error of the final attempt, are re-raised.

    PostgreSQL errors are deliberately not retried. The pooled connections are
    created without autocommit (psycopg's default), so after a serialization
    failure or deadlock the server has already aborted the whole transaction.
    Re-running one statement can only fail again with "current transaction is
    aborted". Callers should roll back and redo the transaction (see
    is_db_contention_error).
    """
    attempts = max(1, retries)
    for attempt in range(attempts):
        try:
            return conn.execute(sql, params)
        except sqlite3.OperationalError as exc:
            if attempt == attempts - 1 or not _is_sqlite_lock_error(exc):
                raise
        time.sleep(base_delay * (attempt + 1))
    raise sqlite3.OperationalError("database is locked")  # unreachable: the loop returns or raises


def commit_retry(
    conn: DbConnection,
    *,
    retries: int = 6,
    base_delay: float = 0.05,
) -> None:
    """Commit *conn*, backing off and retrying while SQLite reports a lock.

    Same schedule as execute_retry. A SQLite COMMIT that fails with a lock
    leaves the transaction active, so it can safely be retried.

    A failed PostgreSQL COMMIT has already ended (rolled back) the transaction.
    Calling commit() again would succeed as a no-op and report the lost writes
    as committed, so PostgreSQL errors are always re-raised immediately.
    """
    attempts = max(1, retries)
    for attempt in range(attempts):
        try:
            conn.commit()
            return
        except sqlite3.OperationalError as exc:
            if attempt == attempts - 1 or not _is_sqlite_lock_error(exc):
                raise
        time.sleep(base_delay * (attempt + 1))
    raise sqlite3.OperationalError("database is locked")  # unreachable: the loop returns or raises
|
||
|
||
|
||
def is_db_contention_error(exc: BaseException) -> bool:
    """Tell whether *exc* signals database lock contention.

    SQLite: any ``sqlite3.OperationalError`` whose message mentions "locked".
    PostgreSQL (only when psycopg is installed): a psycopg OperationalError that
    matches either of these:
      * SQLSTATE 40001 (serialization_failure), 40P01 (deadlock_detected) or
        55P03 (lock_not_available);
      * a message mentioning deadlock / serialize / lock.
    Every other exception returns False.
    """
    text = str(exc).lower()
    if isinstance(exc, sqlite3.OperationalError):
        return "locked" in text
    if not (_PSYCOPG_OK and isinstance(exc, PgOperationalError)):
        return False
    if (getattr(exc, "sqlstate", "") or "") in ("40001", "40P01", "55P03"):
        return True
    return any(word in text for word in ("deadlock", "serialize", "lock"))
|
||
|
||
|
||
def reset_backend_for_tests(backend: str | None = None) -> None:
    """Test helper: close the PostgreSQL pool and force the cached backend.

    Passing None clears the cache, so the next db_backend() call re-detects the
    backend from DATABASE_URL. Passing "sqlite" or "postgres" pins it.
    """
    global _backend
    close_pg_pool()
    with _backend_lock:
        _backend = backend
|