Add PostgreSQL production backend to eliminate SQLite lock contention.
Support DATABASE_URL with connection pooling, pg_dump backups, SQLite migration script, and deploy_postgres.sh with docs. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
+292
-18
@@ -3,70 +3,344 @@
|
||||
# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。
|
||||
# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md
|
||||
|
||||
"""SQLite 连接统一配置(WAL + busy_timeout,降低并发锁冲突)。"""
|
||||
"""数据库连接:开发默认 SQLite,生产推荐 PostgreSQL(DATABASE_URL)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Iterable, Optional, Sequence
|
||||
|
||||
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "futures.db")
|
||||
|
||||
_backend_lock = threading.Lock()
|
||||
_backend: Optional[str] = None
|
||||
_pg_pool = None
|
||||
_pg_pool_lock = threading.Lock()
|
||||
|
||||
try:
|
||||
import psycopg
|
||||
from psycopg import OperationalError as PgOperationalError
|
||||
from psycopg import IntegrityError as PgIntegrityError
|
||||
from psycopg.rows import dict_row
|
||||
from psycopg_pool import ConnectionPool
|
||||
|
||||
_PSYCOPG_OK = True
|
||||
except ImportError:
|
||||
psycopg = None # type: ignore[assignment]
|
||||
PgOperationalError = Exception # type: ignore[misc,assignment]
|
||||
PgIntegrityError = Exception # type: ignore[misc,assignment]
|
||||
dict_row = None # type: ignore[assignment]
|
||||
ConnectionPool = None # type: ignore[misc,assignment]
|
||||
_PSYCOPG_OK = False
|
||||
|
||||
OperationalError = sqlite3.OperationalError
|
||||
IntegrityError = sqlite3.IntegrityError
|
||||
|
||||
|
||||
def db_backend() -> str:
|
||||
"""``sqlite`` 或 ``postgres``。"""
|
||||
global _backend
|
||||
if _backend is not None:
|
||||
return _backend
|
||||
with _backend_lock:
|
||||
if _backend is not None:
|
||||
return _backend
|
||||
url = (os.getenv("DATABASE_URL") or "").strip()
|
||||
if url.startswith(("postgresql://", "postgres://")):
|
||||
if not _PSYCOPG_OK:
|
||||
raise RuntimeError(
|
||||
"已配置 DATABASE_URL 但未安装 psycopg,请执行: pip install 'psycopg[binary]' psycopg-pool"
|
||||
)
|
||||
_backend = "postgres"
|
||||
else:
|
||||
_backend = "sqlite"
|
||||
return _backend
|
||||
|
||||
|
||||
def is_postgres() -> bool:
|
||||
return db_backend() == "postgres"
|
||||
|
||||
|
||||
def database_label() -> str:
|
||||
if is_postgres():
|
||||
url = (os.getenv("DATABASE_URL") or "").strip()
|
||||
host = url.split("@")[-1].split("/")[0] if "@" in url else "postgresql"
|
||||
return f"PostgreSQL ({host})"
|
||||
return f"SQLite ({DB_PATH})"
|
||||
|
||||
|
||||
def adapt_sql(sql: str) -> str:
|
||||
"""将 SQLite 风格 SQL 适配为当前后端。"""
|
||||
if not is_postgres():
|
||||
return sql
|
||||
out = sql
|
||||
out = re.sub(
|
||||
r"\bINTEGER PRIMARY KEY AUTOINCREMENT\b",
|
||||
"SERIAL PRIMARY KEY",
|
||||
out,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
out = re.sub(r"\bAUTOINCREMENT\b", "", out, flags=re.IGNORECASE)
|
||||
if "?" in out:
|
||||
out = out.replace("?", "%s")
|
||||
return out
|
||||
|
||||
|
||||
def is_benign_migration_error(exc: BaseException) -> bool:
|
||||
"""ALTER TABLE 重复列等初始化迁移可忽略的错误。"""
|
||||
msg = str(exc).lower()
|
||||
if any(
|
||||
x in msg
|
||||
for x in (
|
||||
"duplicate column",
|
||||
"already exists",
|
||||
"duplicate key",
|
||||
)
|
||||
):
|
||||
return True
|
||||
if isinstance(exc, sqlite3.OperationalError) and "duplicate column" in msg:
|
||||
return True
|
||||
if _PSYCOPG_OK and isinstance(exc, PgOperationalError):
|
||||
code = getattr(exc, "sqlstate", "") or ""
|
||||
if code in ("42701", "42P07"): # duplicate_column, duplicate_table
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class DbCursor:
|
||||
"""统一 cursor:兼容 sqlite3 的 execute / fetchone / lastrowid。"""
|
||||
|
||||
def __init__(self, backend: str, raw_cursor: Any, raw_conn: Any) -> None:
|
||||
self._backend = backend
|
||||
self._cur = raw_cursor
|
||||
self._conn = raw_conn
|
||||
self.lastrowid: Optional[int] = None
|
||||
self.rowcount: int = 0
|
||||
|
||||
def execute(self, sql: str, params: Sequence[Any] | None = None) -> "DbCursor":
|
||||
sql = adapt_sql(sql)
|
||||
params = params or ()
|
||||
self._cur.execute(sql, params)
|
||||
self.rowcount = int(getattr(self._cur, "rowcount", 0) or 0)
|
||||
self.lastrowid = getattr(self._cur, "lastrowid", None)
|
||||
if self.lastrowid is None and is_postgres():
|
||||
if re.match(r"^\s*INSERT\b", sql, re.IGNORECASE):
|
||||
try:
|
||||
row = self._cur.fetchone()
|
||||
if row is not None:
|
||||
if isinstance(row, dict):
|
||||
self.lastrowid = int(row.get("id") or row.get("Id") or 0) or None
|
||||
else:
|
||||
self.lastrowid = int(row[0])
|
||||
except Exception:
|
||||
try:
|
||||
self._cur.execute("SELECT lastval()")
|
||||
lv = self._cur.fetchone()
|
||||
if lv:
|
||||
self.lastrowid = int(lv[0] if not isinstance(lv, dict) else lv["lastval"])
|
||||
except Exception:
|
||||
pass
|
||||
return self
|
||||
|
||||
def fetchone(self) -> Any:
|
||||
return self._cur.fetchone()
|
||||
|
||||
def fetchall(self) -> list[Any]:
|
||||
return self._cur.fetchall()
|
||||
|
||||
def close(self) -> None:
|
||||
try:
|
||||
self._cur.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class DbConnection:
|
||||
"""统一连接:execute / commit / close,接口对齐 sqlite3.Connection。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend: str,
|
||||
raw_conn: Any,
|
||||
*,
|
||||
from_pool: bool = False,
|
||||
) -> None:
|
||||
self._backend = backend
|
||||
self._conn = raw_conn
|
||||
self._from_pool = from_pool
|
||||
self.row_factory = None
|
||||
|
||||
def execute(self, sql: str, params: Sequence[Any] | None = None) -> DbCursor:
|
||||
cur = self.cursor()
|
||||
return cur.execute(sql, params)
|
||||
|
||||
def cursor(self) -> DbCursor:
|
||||
if self._backend == "sqlite":
|
||||
return DbCursor(self._backend, self._conn.cursor(), self._conn)
|
||||
raw = self._conn.cursor(row_factory=dict_row)
|
||||
return DbCursor(self._backend, raw, self._conn)
|
||||
|
||||
def commit(self) -> None:
|
||||
self._conn.commit()
|
||||
|
||||
def rollback(self) -> None:
|
||||
self._conn.rollback()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._backend == "postgres" and self._from_pool:
|
||||
try:
|
||||
self._conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
try:
|
||||
self._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def __enter__(self) -> "DbConnection":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
if exc:
|
||||
try:
|
||||
self.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
self.commit()
|
||||
except Exception:
|
||||
pass
|
||||
self.close()
|
||||
|
||||
|
||||
def _pg_pool_instance() -> ConnectionPool:
|
||||
global _pg_pool
|
||||
if _pg_pool is not None:
|
||||
return _pg_pool
|
||||
with _pg_pool_lock:
|
||||
if _pg_pool is not None:
|
||||
return _pg_pool
|
||||
url = (os.getenv("DATABASE_URL") or "").strip()
|
||||
min_size = max(1, int(os.getenv("PG_POOL_MIN", "2") or 2))
|
||||
max_size = max(min_size, int(os.getenv("PG_POOL_MAX", "20") or 20))
|
||||
_pg_pool = ConnectionPool(
|
||||
conninfo=url,
|
||||
min_size=min_size,
|
||||
max_size=max_size,
|
||||
kwargs={"row_factory": dict_row},
|
||||
open=True,
|
||||
)
|
||||
return _pg_pool
|
||||
|
||||
|
||||
def connect_db(path: str | None = None) -> DbConnection:
|
||||
"""获取数据库连接。PostgreSQL 使用连接池;SQLite 每次新建连接(WAL)。"""
|
||||
if is_postgres():
|
||||
pool = _pg_pool_instance()
|
||||
raw = pool.getconn()
|
||||
try:
|
||||
with raw.cursor() as cur:
|
||||
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
|
||||
raw.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return DbConnection("postgres", raw, from_pool=True)
|
||||
|
||||
def connect_db(path: str | None = None) -> sqlite3.Connection:
|
||||
db_path = path or DB_PATH
|
||||
conn = sqlite3.connect(db_path, timeout=30, check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA busy_timeout=30000")
|
||||
raw = sqlite3.connect(db_path, timeout=30, check_same_thread=False)
|
||||
raw.row_factory = sqlite3.Row
|
||||
raw.execute("PRAGMA busy_timeout=30000")
|
||||
try:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
raw.execute("PRAGMA journal_mode=WAL")
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
return conn
|
||||
return DbConnection("sqlite", raw)
|
||||
|
||||
|
||||
def close_pg_pool() -> None:
|
||||
global _pg_pool
|
||||
with _pg_pool_lock:
|
||||
if _pg_pool is not None:
|
||||
_pg_pool.close()
|
||||
_pg_pool = None
|
||||
|
||||
|
||||
def execute_retry(
|
||||
conn: sqlite3.Connection,
|
||||
conn: DbConnection,
|
||||
sql: str,
|
||||
params: tuple = (),
|
||||
*,
|
||||
retries: int = 6,
|
||||
base_delay: float = 0.05,
|
||||
) -> sqlite3.Cursor:
|
||||
"""遇 database is locked 时短暂退避重试。"""
|
||||
) -> DbCursor:
|
||||
"""遇锁冲突时短暂退避重试(SQLite locked / PG serialization)。"""
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
return conn.execute(sql, params)
|
||||
except sqlite3.OperationalError as exc:
|
||||
if "locked" not in str(exc).lower():
|
||||
except (OperationalError, PgOperationalError) as exc:
|
||||
msg = str(exc).lower()
|
||||
retryable = "locked" in msg or "serialize" in msg or "deadlock" in msg
|
||||
if not retryable:
|
||||
raise
|
||||
last_exc = exc
|
||||
if attempt < retries - 1:
|
||||
time.sleep(base_delay * (attempt + 1))
|
||||
if last_exc:
|
||||
raise last_exc
|
||||
raise sqlite3.OperationalError("database is locked")
|
||||
raise OperationalError("database is locked")
|
||||
|
||||
|
||||
def commit_retry(
|
||||
conn: sqlite3.Connection,
|
||||
conn: DbConnection,
|
||||
*,
|
||||
retries: int = 6,
|
||||
base_delay: float = 0.05,
|
||||
) -> None:
|
||||
"""遇 database is locked 时短暂退避重试 commit。"""
|
||||
"""遇锁冲突时短暂退避重试 commit。"""
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
conn.commit()
|
||||
return
|
||||
except sqlite3.OperationalError as exc:
|
||||
if "locked" not in str(exc).lower():
|
||||
except (OperationalError, PgOperationalError) as exc:
|
||||
msg = str(exc).lower()
|
||||
retryable = "locked" in msg or "serialize" in msg or "deadlock" in msg
|
||||
if not retryable:
|
||||
raise
|
||||
last_exc = exc
|
||||
if attempt < retries - 1:
|
||||
time.sleep(base_delay * (attempt + 1))
|
||||
if last_exc:
|
||||
raise last_exc
|
||||
raise sqlite3.OperationalError("database is locked")
|
||||
raise OperationalError("database is locked")
|
||||
|
||||
|
||||
def is_db_contention_error(exc: BaseException) -> bool:
|
||||
"""SQLite locked / PostgreSQL serialization / deadlock。"""
|
||||
msg = str(exc).lower()
|
||||
if isinstance(exc, sqlite3.OperationalError):
|
||||
return "locked" in msg
|
||||
if _PSYCOPG_OK and isinstance(exc, PgOperationalError):
|
||||
code = getattr(exc, "sqlstate", "") or ""
|
||||
if code in ("40001", "40P01", "55P03"):
|
||||
return True
|
||||
return any(x in msg for x in ("deadlock", "serialize", "lock"))
|
||||
return False
|
||||
|
||||
|
||||
def reset_backend_for_tests(backend: str | None = None) -> None:
|
||||
"""测试用:重置后端检测与连接池。"""
|
||||
global _backend, _pg_pool
|
||||
close_pg_pool()
|
||||
with _backend_lock:
|
||||
_backend = backend
|
||||
|
||||
Reference in New Issue
Block a user