Files
qihuo/db_conn.py
T

369 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright (c) 2025-2026 马建军. All rights reserved.
# 专有软件 — 未经授权禁止复制、传播、转售。
# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。
# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md
"""数据库连接:开发默认 SQLite,生产推荐 PostgreSQL(DATABASE_URL)。"""
from __future__ import annotations
import os
import re
import sqlite3
import threading
import time
from typing import Any, Iterable, Optional, Sequence
# Default SQLite database file, located next to this module.
DB_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "futures.db",
)

# Cached backend name ("sqlite" / "postgres"); None until db_backend() runs.
_backend: Optional[str] = None
_backend_lock = threading.Lock()

# Shared PostgreSQL connection pool; None until first requested.
_pg_pool = None
_pg_pool_lock = threading.Lock()
# Optional PostgreSQL driver. psycopg 3 and psycopg_pool are only needed when
# DATABASE_URL selects PostgreSQL; db_backend() raises RuntimeError in that
# case if this import failed (_PSYCOPG_OK is False).
try:
    import psycopg
    from psycopg import OperationalError as PgOperationalError
    from psycopg import IntegrityError as PgIntegrityError
    from psycopg.rows import dict_row
    from psycopg_pool import ConnectionPool
    _PSYCOPG_OK = True
except ImportError:
    # Placeholders so the module-level names always exist without the driver.
    # NOTE: the Pg* error names fall back to plain ``Exception``, so an
    # ``except (..., PgOperationalError)`` clause then catches everything;
    # handlers must re-check the exception before treating it as a DB error.
    psycopg = None  # type: ignore[assignment]
    PgOperationalError = Exception  # type: ignore[misc,assignment]
    PgIntegrityError = Exception  # type: ignore[misc,assignment]
    dict_row = None  # type: ignore[assignment]
    ConnectionPool = None  # type: ignore[misc,assignment]
    _PSYCOPG_OK = False
# Module-level error aliases. These are the sqlite3 classes only; psycopg
# errors do not derive from them, so PostgreSQL code must also catch the
# Pg* names above.
OperationalError = sqlite3.OperationalError
IntegrityError = sqlite3.IntegrityError
def db_backend() -> str:
    """Return the active backend name, ``"sqlite"`` or ``"postgres"``.

    The backend is detected once from ``DATABASE_URL`` and cached in the
    module-level ``_backend``:

    * ``postgresql://`` or ``postgres://`` selects PostgreSQL.
    * Anything else, including an unset variable, selects SQLite.

    Raises:
        RuntimeError: the URL selects PostgreSQL but psycopg could not be
            imported. Nothing is cached in that case.
    """
    global _backend
    if _backend is None:
        with _backend_lock:
            # Re-check under the lock: another thread may have filled it in.
            if _backend is None:
                url = (os.getenv("DATABASE_URL") or "").strip()
                wants_postgres = url.startswith(("postgresql://", "postgres://"))
                if wants_postgres and not _PSYCOPG_OK:
                    raise RuntimeError(
                        "已配置 DATABASE_URL 但未安装 psycopg,请执行: pip install 'psycopg[binary]' psycopg-pool"
                    )
                _backend = "postgres" if wants_postgres else "sqlite"
            return _backend
    return _backend
def is_postgres() -> bool:
    """Return True when the active backend reported by db_backend() is PostgreSQL."""
    backend = db_backend()
    return backend == "postgres"
def database_label() -> str:
    """Return a human-readable description of the active database.

    SQLite reports the database file path. PostgreSQL reports only the
    ``host[:port]`` part of ``DATABASE_URL``. User name, password, path and
    query string are never included. When the URL has no host (e.g. a
    Unix-socket URL such as ``postgresql:///db``), the label falls back to
    ``"postgresql"``.
    """
    if not is_postgres():
        return f"SQLite ({DB_PATH})"
    from urllib.parse import urlsplit

    url = (os.getenv("DATABASE_URL") or "").strip()
    try:
        netloc = urlsplit(url).netloc
    except ValueError:  # e.g. a malformed bracketed IPv6 host
        netloc = ""
    # Drop "user:password@" so credentials never reach logs or the UI.
    # rpartition still works when an unencoded password contains "@".
    host = netloc.rpartition("@")[2] or "postgresql"
    return f"PostgreSQL ({host})"
def _qmark_to_pyformat(sql: str) -> str:
    """Rewrite sqlite ``?`` placeholders as psycopg ``%s``.

    A ``?`` inside a single-quoted literal or a double-quoted identifier is
    left untouched.
    """
    if "?" not in sql:
        return sql
    pieces: list[str] = []
    quote: Optional[str] = None  # quote char of the literal we are inside
    for ch in sql:
        if quote is not None:
            # A doubled quote ('') closes and immediately reopens the literal,
            # so toggling on every quote char handles SQL's escaping rule.
            if ch == quote:
                quote = None
        elif ch in ("'", '"'):
            quote = ch
        elif ch == "?":
            pieces.append("%s")
            continue
        pieces.append(ch)
    return "".join(pieces)


def adapt_sql(sql: str) -> str:
    """Adapt SQLite-style SQL to the current backend.

    SQLite: *sql* is returned unchanged. PostgreSQL:

    * ``INTEGER PRIMARY KEY AUTOINCREMENT`` becomes ``SERIAL PRIMARY KEY``.
    * Any remaining ``AUTOINCREMENT`` is dropped.
    * ``DEFAULT "x"`` becomes ``DEFAULT 'x'``.
    * ``?`` placeholders become ``%s``. A ``?`` inside a quoted literal or
      identifier is no longer rewritten; previously it became a placeholder
      and broke the parameter count.
    """
    if not is_postgres():
        return sql
    out = re.sub(
        r"\bINTEGER PRIMARY KEY AUTOINCREMENT\b",
        "SERIAL PRIMARY KEY",
        sql,
        flags=re.IGNORECASE,
    )
    out = re.sub(r"\bAUTOINCREMENT\b", "", out, flags=re.IGNORECASE)
    out = re.sub(r'DEFAULT\s+"([^"]*)"', r"DEFAULT '\1'", out, flags=re.IGNORECASE)
    return _qmark_to_pyformat(out)
def is_benign_migration_error(exc: BaseException) -> bool:
    """Return True if *exc* is an init-migration error that can be ignored.

    Examples include ``ALTER TABLE`` re-adding an existing column. This is a
    thin alias of is_schema_migration_error().
    """
    return bool(is_schema_migration_error(exc))
def is_schema_migration_error(exc: BaseException) -> bool:
    """Return True if *exc* is an ignorable incremental-migration error.

    Incremental ``init_db`` migrations may hit a missing table or column, or
    try to create a table or column that already exists. These errors are
    safe to skip.

    Detection uses two checks:

    * Message text, which covers SQLite and English PostgreSQL messages.
    * The exception's ``sqlstate`` attribute, for any exception that has one.
      The codes checked are 42701 duplicate_column, 42P07 duplicate_table,
      42P01 undefined_table and 42703 undefined_column. This keeps working
      when the server sends localized messages (e.g. ``lc_messages=zh_CN``).

    psycopg reports these class-42 errors as ``ProgrammingError`` subclasses,
    not ``OperationalError``. The SQLSTATE check must therefore not be gated
    on ``OperationalError``.

    NOTE(review): "duplicate key" also matches unique-constraint violations
    (SQLSTATE 23505) from seed INSERTs; confirm that is meant to be ignored.
    """
    msg = str(exc).lower()
    markers = (
        "duplicate column",
        "already exists",
        "duplicate key",
        "no such table",
        "does not exist",
        "undefined table",
        "undefined column",
    )
    if any(marker in msg for marker in markers):
        return True
    # sqlite3 exceptions have no ``sqlstate``; psycopg errors expose it (or None).
    sqlstate = getattr(exc, "sqlstate", None) or ""
    return sqlstate in ("42701", "42P07", "42P01", "42703")
def rollback_if_postgres(conn: "DbConnection") -> None:
    """Roll back *conn* on a best-effort basis, on PostgreSQL only.

    After a failed statement, PostgreSQL rejects further commands in the same
    transaction until it is rolled back; this helper performs that reset.
    On SQLite it does nothing. Any error raised by the rollback itself is
    deliberately ignored.
    """
    if not is_postgres():
        return
    try:
        conn.rollback()
    except Exception:
        pass  # best effort by design
class DbCursor:
    """Backend-neutral cursor exposing the sqlite3-style API used by callers.

    Wraps a sqlite3 or psycopg cursor. Provides ``execute`` / ``fetchone`` /
    ``fetchall`` / ``close`` plus the ``lastrowid`` and ``rowcount``
    attributes.
    """

    def __init__(self, backend: str, raw_cursor: Any, raw_conn: Any) -> None:
        self._backend = backend  # "sqlite" or "postgres"
        self._cur = raw_cursor  # driver cursor that runs the statements
        self._conn = raw_conn  # driver connection owning that cursor
        # Id of the row inserted by the last execute(), or None if unknown.
        self.lastrowid: Optional[int] = None
        # Driver-reported row count of the last execute(); 0 when unreported.
        self.rowcount: int = 0

    def execute(self, sql: str, params: Sequence[Any] | None = None) -> "DbCursor":
        """Adapt *sql* with ``adapt_sql``, run it, and return ``self``.

        On PostgreSQL, an empty or omitted *params* is passed to psycopg as
        ``None``. psycopg then sends the statement verbatim instead of parsing
        it for placeholders. As a result, a parameterless statement with a
        literal ``%`` (e.g. ``LIKE 'a%'``) is no longer rejected as a
        malformed placeholder.

        NOTE(review): a literal ``%`` in a statement that *does* take
        parameters must still be written ``%%`` for psycopg.

        After an INSERT on PostgreSQL, ``lastrowid`` is filled best-effort by
        ``_pg_lastrowid``.
        """
        sql = adapt_sql(sql)
        if self._backend == "postgres":
            self._cur.execute(sql, params or None)
        else:
            self._cur.execute(sql, params or ())
        self.rowcount = int(getattr(self._cur, "rowcount", 0) or 0)
        self.lastrowid = getattr(self._cur, "lastrowid", None)
        if (
            self.lastrowid is None
            and is_postgres()
            and re.match(r"^\s*INSERT\b", sql, re.IGNORECASE)
        ):
            self.lastrowid = self._pg_lastrowid()
        return self

    def _pg_lastrowid(self) -> Optional[int]:
        """Best-effort id of the row just inserted on PostgreSQL, else None.

        * ``INSERT ... RETURNING``: the id is read from the first returned
          row, and the cursor is rewound so the caller can still fetch it.
        * Plain INSERT: runs ``SELECT lastval()`` on a separate cursor inside
          ``conn.transaction()``, which is a savepoint when a transaction is
          open. A failing ``lastval()`` (no sequence used yet in this session)
          therefore no longer aborts the caller's transaction.
          NOTE(review): ``lastval()`` is the session's most recent sequence
          value, which may belong to another table.
        """
        if getattr(self._cur, "description", None) is not None:
            try:
                row = self._cur.fetchone()
            except Exception:
                return None
            if row is not None:
                try:
                    self._cur.scroll(-1)  # un-consume the row for the caller
                except Exception:
                    pass
            return self._row_id(row)
        try:
            with self._conn.transaction():
                with self._conn.cursor() as cur:
                    cur.execute("SELECT lastval()")
                    lv = cur.fetchone()
            if not lv:
                return None
            return int(lv["lastval"] if isinstance(lv, dict) else lv[0])
        except Exception:
            return None

    @staticmethod
    def _row_id(row: Any) -> Optional[int]:
        """Extract an integer id from a RETURNING row.

        Dict rows use the ``id``/``Id`` key; other rows use the first column.
        Returns None when no usable id is present.
        """
        if row is None:
            return None
        try:
            if isinstance(row, dict):
                return int(row.get("id") or row.get("Id") or 0) or None
            return int(row[0])
        except (TypeError, ValueError, LookupError):
            return None

    def fetchone(self) -> Any:
        """Return the next row from the wrapped cursor."""
        return self._cur.fetchone()

    def fetchall(self) -> list[Any]:
        """Return all remaining rows from the wrapped cursor."""
        return self._cur.fetchall()

    def close(self) -> None:
        """Close the wrapped cursor, ignoring any error."""
        try:
            self._cur.close()
        except Exception:
            pass
class DbConnection:
    """Backend-neutral connection mirroring the ``sqlite3.Connection`` API used here.

    Supports ``execute`` / ``cursor`` / ``commit`` / ``rollback`` / ``close``
    and use as a context manager. PostgreSQL connections checked out of the
    module pool (``from_pool=True``) are handed back to that pool on
    ``close()`` rather than discarded.
    """

    def __init__(
        self,
        backend: str,
        raw_conn: Any,
        *,
        from_pool: bool = False,
    ) -> None:
        self._backend = backend  # "sqlite" or "postgres"
        self._conn = raw_conn
        self._from_pool = from_pool
        self._closed = False  # makes close() idempotent
        # Kept for sqlite3.Connection API compatibility; not read by this class.
        self.row_factory = None

    def execute(self, sql: str, params: Sequence[Any] | None = None) -> DbCursor:
        """Run *sql* on a fresh cursor and return that cursor."""
        cur = self.cursor()
        return cur.execute(sql, params)

    def cursor(self) -> DbCursor:
        """Return a new DbCursor; PostgreSQL cursors produce dict rows."""
        if self._backend == "sqlite":
            return DbCursor(self._backend, self._conn.cursor(), self._conn)
        raw = self._conn.cursor(row_factory=dict_row)
        return DbCursor(self._backend, raw, self._conn)

    def commit(self) -> None:
        """Commit the current transaction."""
        self._conn.commit()

    def rollback(self) -> None:
        """Roll back the current transaction."""
        self._conn.rollback()

    def close(self) -> None:
        """Release the underlying connection; safe to call more than once.

        A pooled PostgreSQL connection is rolled back, so no open transaction
        leaks to the next user, and is then returned with ``putconn``.
        Calling ``close()`` on a checked-out psycopg_pool connection does not
        free its pool slot, so the pool would otherwise run dry after
        ``PG_POOL_MAX`` connect/close cycles. If the pool is gone (e.g. after
        ``close_pg_pool()``) or refuses the connection, the connection is
        closed instead. Errors during cleanup are ignored.
        """
        if self._closed:
            return
        # Mark first: a second putconn could hand out a connection in use.
        self._closed = True
        if self._backend == "postgres" and self._from_pool:
            try:
                self._conn.rollback()
            except Exception:
                pass
            pool = _pg_pool
            if pool is not None:
                try:
                    pool.putconn(self._conn)
                    return
                except Exception:
                    pass  # e.g. connection came from an older pool
        try:
            self._conn.close()
        except Exception:
            pass

    def __enter__(self) -> "DbConnection":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Commit on success, roll back on error, then always close.

        A failed commit is rolled back and re-raised (after closing) instead
        of being silently dropped, so callers learn that their writes were
        not persisted. Returns None, so exceptions raised in the ``with``
        body propagate.
        """
        try:
            if exc is not None:
                try:
                    self.rollback()
                except Exception:
                    pass
            else:
                try:
                    self.commit()
                except Exception:
                    try:
                        self.rollback()
                    except Exception:
                        pass
                    raise
        finally:
            self.close()
def _pg_pool_instance() -> ConnectionPool:
    """Return the shared psycopg ConnectionPool, creating it on first use.

    Configuration:

    * The pool connects to ``DATABASE_URL``.
    * ``PG_POOL_MIN`` defaults to 2 and is clamped to at least 1.
    * ``PG_POOL_MAX`` defaults to 20 and is never below the minimum.
    * Pooled connections use ``dict_row``.
    """
    global _pg_pool
    if _pg_pool is None:
        with _pg_pool_lock:
            # Re-check under the lock so only one thread builds the pool.
            if _pg_pool is None:
                conninfo = (os.getenv("DATABASE_URL") or "").strip()
                lower = max(1, int(os.getenv("PG_POOL_MIN", "2") or 2))
                upper = max(lower, int(os.getenv("PG_POOL_MAX", "20") or 20))
                _pg_pool = ConnectionPool(
                    conninfo=conninfo,
                    min_size=lower,
                    max_size=upper,
                    kwargs={"row_factory": dict_row},
                    open=True,
                )
            return _pg_pool
    return _pg_pool
def connect_db(path: str | None = None) -> DbConnection:
    """Open a database connection for the active backend.

    PostgreSQL:
        Checks a connection out of the shared pool and sets its session time
        zone to Asia/Shanghai; *path* is ignored. If setting the time zone
        fails, the connection is rolled back. Without that rollback it would
        be handed out in an aborted transaction, and PostgreSQL would reject
        every later statement.

    SQLite:
        Opens a new connection to *path* (default ``DB_PATH``) with
        ``sqlite3.Row`` rows and a 30 s busy timeout, then tries WAL
        journaling. An ``OperationalError`` from the WAL pragma is ignored.
        If setup fails, the raw connection is closed before the error
        propagates.
    """
    if is_postgres():
        raw = _pg_pool_instance().getconn()
        try:
            with raw.cursor() as cur:
                cur.execute("SET TIME ZONE 'Asia/Shanghai'")
            raw.commit()
        except Exception:
            # Best effort: hand the connection out without the time zone,
            # but never in an aborted transaction.
            try:
                raw.rollback()
            except Exception:
                pass
        return DbConnection("postgres", raw, from_pool=True)
    raw = sqlite3.connect(path or DB_PATH, timeout=30, check_same_thread=False)
    try:
        raw.row_factory = sqlite3.Row
        raw.execute("PRAGMA busy_timeout=30000")
        try:
            raw.execute("PRAGMA journal_mode=WAL")
        except sqlite3.OperationalError:
            pass
    except BaseException:
        raw.close()
        raise
    return DbConnection("sqlite", raw)
def close_pg_pool() -> None:
    """Close and forget the shared PostgreSQL pool, if one exists.

    After this, the next _pg_pool_instance() call builds a new pool.
    """
    global _pg_pool
    with _pg_pool_lock:
        pool = _pg_pool
        if pool is None:
            return
        pool.close()
        _pg_pool = None
def execute_retry(
    conn: DbConnection,
    sql: str,
    params: tuple = (),
    *,
    retries: int = 6,
    base_delay: float = 0.05,
) -> DbCursor:
    """Execute *sql* on *conn*, retrying briefly on lock contention.

    A failure counts as contention when either:

    * its message mentions ``locked``, ``serialize`` or ``deadlock``; or
    * its SQLSTATE is 40001 (serialization_failure) or 40P01
      (deadlock_detected). This check still works when PostgreSQL sends
      localized messages.

    After the k-th failed attempt (0-based), it sleeps
    ``base_delay * (k + 1)`` seconds. The statement is always attempted at
    least once, even when *retries* < 1.

    NOTE(review): on PostgreSQL a failed statement aborts the surrounding
    transaction, so re-executing within it keeps failing until the caller
    rolls back. Retries mainly help SQLite or autocommit connections.

    Raises:
        The last contention error once attempts are exhausted, or any other
        error immediately.
    """
    attempts = max(1, retries)
    last_exc: Exception | None = None
    for attempt in range(attempts):
        try:
            return conn.execute(sql, params)
        except (OperationalError, PgOperationalError) as exc:
            msg = str(exc).lower()
            sqlstate = getattr(exc, "sqlstate", None) or ""
            retryable = (
                sqlstate in ("40001", "40P01")
                or "locked" in msg
                or "serialize" in msg
                or "deadlock" in msg
            )
            if not retryable:
                raise
            last_exc = exc
            if attempt < attempts - 1:
                time.sleep(base_delay * (attempt + 1))
    if last_exc:
        raise last_exc
    raise OperationalError("database is locked")
def commit_retry(
    conn: DbConnection,
    *,
    retries: int = 6,
    base_delay: float = 0.05,
) -> None:
    """Commit *conn*, retrying briefly on lock contention.

    Contention uses the same rules as execute_retry: a message mentioning
    ``locked``, ``serialize`` or ``deadlock``, or SQLSTATE 40001 / 40P01.
    After the k-th failed attempt (0-based), it sleeps
    ``base_delay * (k + 1)`` seconds. The commit is always attempted at least
    once, even when *retries* < 1.

    Raises:
        The last contention error once attempts are exhausted, or any other
        error immediately.
    """
    attempts = max(1, retries)
    last_exc: Exception | None = None
    for attempt in range(attempts):
        try:
            conn.commit()
            return
        except (OperationalError, PgOperationalError) as exc:
            msg = str(exc).lower()
            sqlstate = getattr(exc, "sqlstate", None) or ""
            retryable = (
                sqlstate in ("40001", "40P01")
                or "locked" in msg
                or "serialize" in msg
                or "deadlock" in msg
            )
            if not retryable:
                raise
            last_exc = exc
            if attempt < attempts - 1:
                time.sleep(base_delay * (attempt + 1))
    if last_exc:
        raise last_exc
    raise OperationalError("database is locked")
def is_db_contention_error(exc: BaseException) -> bool:
    """Return True if *exc* signals lock contention.

    * SQLite: an ``OperationalError`` whose message mentions "locked".
    * PostgreSQL: an ``OperationalError`` with SQLSTATE 40001, 40P01 or
      55P03, or whose message mentions deadlock, serialize or lock.

    Everything else returns False.
    """
    text = str(exc).lower()
    if isinstance(exc, sqlite3.OperationalError):
        return "locked" in text
    if not (_PSYCOPG_OK and isinstance(exc, PgOperationalError)):
        return False
    if (getattr(exc, "sqlstate", "") or "") in ("40001", "40P01", "55P03"):
        return True
    return any(needle in text for needle in ("deadlock", "serialize", "lock"))
def reset_backend_for_tests(backend: str | None = None) -> None:
    """Test helper: drop the shared PostgreSQL pool and set the cached backend.

    Pass ``None`` (the default) to make the next db_backend() call re-detect
    the backend from ``DATABASE_URL``. Pass ``"sqlite"`` or ``"postgres"`` to
    pin it directly; no psycopg availability check is performed.
    """
    global _backend
    close_pg_pool()
    with _backend_lock:
        _backend = backend