# Copyright (c) 2025-2026 马建军. All rights reserved.
# 专有软件 — 未经授权禁止复制、传播、转售。
# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。
# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md
"""Database connections: SQLite by default, PostgreSQL when DATABASE_URL is set.

The module exposes thin ``DbConnection`` / ``DbCursor`` wrappers so callers can
write SQLite-style SQL (``?`` placeholders, ``AUTOINCREMENT``) against either
backend.
"""
from __future__ import annotations

import os
import re
import sqlite3
import threading
import time
from typing import Any, Callable, Iterable, Optional, Sequence

DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "futures.db")

_backend_lock = threading.Lock()
_backend: Optional[str] = None
_pg_pool = None
_pg_pool_lock = threading.Lock()

try:
    import psycopg
    from psycopg import OperationalError as PgOperationalError
    from psycopg import IntegrityError as PgIntegrityError
    from psycopg.rows import dict_row
    from psycopg_pool import ConnectionPool

    _PSYCOPG_OK = True
except ImportError:
    # Placeholders so these names always exist when psycopg is not installed.
    psycopg = None  # type: ignore[assignment]
    PgOperationalError = Exception  # type: ignore[misc,assignment]
    PgIntegrityError = Exception  # type: ignore[misc,assignment]
    dict_row = None  # type: ignore[assignment]
    ConnectionPool = None  # type: ignore[misc,assignment]
    _PSYCOPG_OK = False

OperationalError = sqlite3.OperationalError
IntegrityError = sqlite3.IntegrityError

# Patterns used by adapt_sql / DbCursor, compiled once.
_AUTOINC_PK_RE = re.compile(r"\bINTEGER PRIMARY KEY AUTOINCREMENT\b", re.IGNORECASE)
_AUTOINC_RE = re.compile(r"\bAUTOINCREMENT\b", re.IGNORECASE)
# Matches a whole quoted literal/identifier OR a bare "?", so that question
# marks inside quotes can be left untouched.
_QUOTED_OR_QMARK_RE = re.compile(r"'(?:[^']|'')*'|\"(?:[^\"]|\"\")*\"|\?")
_INSERT_RE = re.compile(r"^\s*INSERT\b", re.IGNORECASE)

_BENIGN_MIGRATION_MARKERS = ("duplicate column", "already exists", "duplicate key")
# PostgreSQL SQLSTATEs: duplicate_column, duplicate_table.
_BENIGN_MIGRATION_SQLSTATES = frozenset({"42701", "42P07"})


def db_backend() -> str:
    """Return ``"sqlite"`` or ``"postgres"``; the result is cached.

    The backend is ``postgres`` when ``DATABASE_URL`` starts with
    ``postgresql://`` or ``postgres://``, otherwise ``sqlite``.

    Raises:
        RuntimeError: a PostgreSQL URL is configured but psycopg is missing.
    """
    global _backend
    if _backend is not None:
        return _backend
    with _backend_lock:
        # Re-check under the lock: another thread may have decided already.
        if _backend is not None:
            return _backend
        url = (os.getenv("DATABASE_URL") or "").strip()
        if url.startswith(("postgresql://", "postgres://")):
            if not _PSYCOPG_OK:
                raise RuntimeError(
                    "已配置 DATABASE_URL 但未安装 psycopg,请执行: pip install 'psycopg[binary]' psycopg-pool"
                )
            _backend = "postgres"
        else:
            _backend = "sqlite"
        return _backend


def is_postgres() -> bool:
    """Return True when the active backend is PostgreSQL."""
    return db_backend() == "postgres"


def database_label() -> str:
    """Return a human-readable backend description without credentials."""
    if is_postgres():
        url = (os.getenv("DATABASE_URL") or "").strip()
        # Keep only the host[:port] part after the last "@".
        host = url.split("@")[-1].split("/")[0] if "@" in url else "postgresql"
        return f"PostgreSQL ({host})"
    return f"SQLite ({DB_PATH})"


def _qmark_to_pyformat(match: "re.Match[str]") -> str:
    """Turn a bare ``?`` into ``%s``; return quoted text unchanged."""
    token = match.group(0)
    return "%s" if token == "?" else token


def adapt_sql(sql: str) -> str:
    """Adapt SQLite-style SQL to the active backend.

    On SQLite the text is returned unchanged. On PostgreSQL:

    * ``INTEGER PRIMARY KEY AUTOINCREMENT`` becomes ``SERIAL PRIMARY KEY`` and
      any remaining ``AUTOINCREMENT`` keyword is removed;
    * ``?`` placeholders become ``%s``. Question marks inside single-quoted
      literals or double-quoted identifiers are left alone.

    Literal ``%`` characters are not escaped by this function.
    """
    if not is_postgres():
        return sql
    out = _AUTOINC_PK_RE.sub("SERIAL PRIMARY KEY", sql)
    out = _AUTOINC_RE.sub("", out)
    if "?" in out:
        out = _QUOTED_OR_QMARK_RE.sub(_qmark_to_pyformat, out)
    return out


def is_benign_migration_error(exc: BaseException) -> bool:
    """Return True for errors that init-time migrations may ignore.

    An error is benign when its message contains "duplicate column",
    "already exists" or "duplicate key", or when it carries a ``sqlstate``
    attribute equal to 42701 (duplicate_column) or 42P07 (duplicate_table).
    """
    msg = str(exc).lower()
    if any(marker in msg for marker in _BENIGN_MIGRATION_MARKERS):
        return True
    # Read sqlstate from any exception type instead of gating on
    # OperationalError, so the code is honoured whatever class raised it.
    code = getattr(exc, "sqlstate", None) or ""
    return code in _BENIGN_MIGRATION_SQLSTATES


class DbCursor:
    """Unified cursor exposing sqlite3-style execute / fetchone / lastrowid."""

    def __init__(self, backend: str, raw_cursor: Any, raw_conn: Any) -> None:
        self._backend = backend
        self._cur = raw_cursor
        self._conn = raw_conn
        self.lastrowid: Optional[int] = None
        self.rowcount: int = 0
        # Row read ahead from INSERT ... RETURNING to compute lastrowid; it is
        # handed back to the caller by the next fetchone()/fetchall().
        self._pending: list[Any] = []

    def execute(self, sql: str, params: Sequence[Any] | None = None) -> "DbCursor":
        """Run *sql* (adapted to the backend) and return this cursor.

        ``rowcount`` and ``lastrowid`` are refreshed after every call. On
        PostgreSQL, ``lastrowid`` for an INSERT comes from the RETURNING row
        when there is one, otherwise from ``lastval()``.
        """
        sql = adapt_sql(sql)
        self._pending = []
        self._cur.execute(sql, params or ())
        self.rowcount = int(getattr(self._cur, "rowcount", 0) or 0)
        self.lastrowid = getattr(self._cur, "lastrowid", None)
        if self.lastrowid is None and is_postgres() and _INSERT_RE.match(sql):
            self.lastrowid = self._pg_lastrowid()
        return self

    def _pg_lastrowid(self) -> Optional[int]:
        """Derive the id of the row just inserted on PostgreSQL."""
        if getattr(self._cur, "description", None) is not None:
            # The statement produced a result set (INSERT ... RETURNING).
            try:
                row = self._cur.fetchone()
                if row is None:
                    return None
                self._pending.append(row)
                if isinstance(row, dict):
                    return int(row.get("id") or row.get("Id") or 0) or None
                return int(row[0])
            except Exception:
                # Unreadable or non-integer first column: fall back to lastval().
                pass
        return self._pg_lastval()

    def _pg_lastval(self) -> Optional[int]:
        """Return ``lastval()`` or None, without poisoning the transaction."""
        try:
            # lastval() errors when no sequence was used in this session. The
            # nested transaction (a savepoint when one is already open) keeps
            # that failure from aborting the caller's transaction.
            with self._conn.transaction():
                self._cur.execute("SELECT lastval()")
                found = self._cur.fetchone()
            if not found:
                return None
            return int(found["lastval"] if isinstance(found, dict) else found[0])
        except Exception:
            return None

    def fetchone(self) -> Any:
        """Return the next row, or None when the result set is exhausted."""
        if self._pending:
            return self._pending.pop(0)
        return self._cur.fetchone()

    def fetchall(self) -> list[Any]:
        """Return all remaining rows."""
        if self._pending:
            buffered, self._pending = self._pending, []
            return buffered + list(self._cur.fetchall())
        return self._cur.fetchall()

    def close(self) -> None:
        """Close the underlying cursor, ignoring errors."""
        try:
            self._cur.close()
        except Exception:
            pass


class DbConnection:
    """Unified connection (execute / commit / close) modelled on sqlite3.Connection.

    Args:
        backend: ``"sqlite"`` or ``"postgres"``.
        raw_conn: the underlying driver connection.
        from_pool: True when *raw_conn* was checked out of a connection pool.
        pool: the pool *raw_conn* came from; when omitted, the module-level
            pool is used at close time.
    """

    def __init__(
        self,
        backend: str,
        raw_conn: Any,
        *,
        from_pool: bool = False,
        pool: Any = None,
    ) -> None:
        self._backend = backend
        self._conn = raw_conn
        self._from_pool = from_pool
        self._pool = pool
        self._closed = False
        self.row_factory = None

    def execute(self, sql: str, params: Sequence[Any] | None = None) -> DbCursor:
        """Open a new cursor, run *sql* on it and return the cursor."""
        return self.cursor().execute(sql, params)

    def cursor(self) -> DbCursor:
        """Return a new DbCursor; PostgreSQL cursors yield dict rows."""
        if self._backend == "sqlite":
            return DbCursor(self._backend, self._conn.cursor(), self._conn)
        raw = self._conn.cursor(row_factory=dict_row)
        return DbCursor(self._backend, raw, self._conn)

    def commit(self) -> None:
        """Commit the current transaction."""
        self._conn.commit()

    def rollback(self) -> None:
        """Roll back the current transaction."""
        self._conn.rollback()

    def close(self) -> None:
        """Release the connection; calling close() again is a no-op.

        A pooled PostgreSQL connection is rolled back and handed back to its
        pool with ``putconn()``. If no pool is available, or returning fails,
        the raw connection is closed instead. Errors are ignored.
        """
        if self._closed:
            return
        self._closed = True
        if self._backend == "postgres" and self._from_pool:
            try:
                # Discard any uncommitted work before the connection is reused.
                self._conn.rollback()
            except Exception:
                pass
            pool = self._pool if self._pool is not None else _pg_pool
            if pool is not None:
                try:
                    pool.putconn(self._conn)
                    return
                except Exception:
                    # Could not return it; fall through and close it outright.
                    pass
        try:
            self._conn.close()
        except Exception:
            pass

    def __enter__(self) -> "DbConnection":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Commit on success, roll back on error, then always close.

        A failing commit is propagated to the caller rather than hidden.
        """
        try:
            if exc_type is not None:
                try:
                    self.rollback()
                except Exception:
                    # Best effort: the original exception is already propagating.
                    pass
            else:
                self.commit()
        finally:
            self.close()


def _env_int(name: str, default: int) -> int:
    """Read an integer environment variable; use *default* if unset or invalid."""
    raw = (os.getenv(name) or "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        return default


def _pg_pool_instance() -> ConnectionPool:
    """Return the process-wide PostgreSQL pool, creating it on first use.

    Pool bounds come from ``PG_POOL_MIN`` (default 2, at least 1) and
    ``PG_POOL_MAX`` (default 20, at least the minimum).
    """
    global _pg_pool
    if _pg_pool is not None:
        return _pg_pool
    with _pg_pool_lock:
        if _pg_pool is not None:
            return _pg_pool
        url = (os.getenv("DATABASE_URL") or "").strip()
        min_size = max(1, _env_int("PG_POOL_MIN", 2))
        max_size = max(min_size, _env_int("PG_POOL_MAX", 20))
        _pg_pool = ConnectionPool(
            conninfo=url,
            min_size=min_size,
            max_size=max_size,
            kwargs={"row_factory": dict_row},
            open=True,
        )
        return _pg_pool


def connect_db(path: str | None = None) -> DbConnection:
    """Return a database connection for the active backend.

    PostgreSQL connections are checked out of the shared pool and have their
    session time zone set to Asia/Shanghai. SQLite opens a new connection to
    *path* (default ``DB_PATH``) with a 30 s busy timeout and WAL journaling
    when the database accepts it.
    """
    if is_postgres():
        pool = _pg_pool_instance()
        raw = pool.getconn()
        try:
            with raw.cursor() as cur:
                cur.execute("SET TIME ZONE 'Asia/Shanghai'")
            raw.commit()
        except Exception:
            # Best effort, but do not hand out a connection whose transaction
            # is left in a failed state.
            try:
                raw.rollback()
            except Exception:
                pass
        return DbConnection("postgres", raw, from_pool=True, pool=pool)

    db_path = path or DB_PATH
    raw = sqlite3.connect(db_path, timeout=30, check_same_thread=False)
    raw.row_factory = sqlite3.Row
    raw.execute("PRAGMA busy_timeout=30000")
    try:
        raw.execute("PRAGMA journal_mode=WAL")
    except sqlite3.OperationalError:
        pass
    return DbConnection("sqlite", raw)


def close_pg_pool() -> None:
    """Close and forget the PostgreSQL pool, if one was created."""
    global _pg_pool
    with _pg_pool_lock:
        if _pg_pool is not None:
            _pg_pool.close()
            _pg_pool = None


def _is_retryable_message(exc: BaseException) -> bool:
    """Return True when the error text mentions a lock/serialization conflict."""
    msg = str(exc).lower()
    return "locked" in msg or "serialize" in msg or "deadlock" in msg


def _run_with_retry(operation: Callable[[], Any], retries: int, base_delay: float) -> Any:
    """Call *operation*, retrying on contention with a linearly growing delay.

    At least one attempt is always made. Non-retryable errors and the error
    from the final attempt are re-raised unchanged.
    """
    attempts = max(1, retries)
    for attempt in range(attempts):
        try:
            return operation()
        except (OperationalError, PgOperationalError) as exc:
            if not _is_retryable_message(exc) or attempt == attempts - 1:
                raise
            time.sleep(base_delay * (attempt + 1))
    # Not reachable: the loop either returns or raises.
    raise OperationalError("database is locked")


def execute_retry(
    conn: DbConnection,
    sql: str,
    params: tuple = (),
    *,
    retries: int = 6,
    base_delay: float = 0.05,
) -> DbCursor:
    """Execute *sql*, backing off briefly and retrying on lock conflicts.

    An error is retried when its message contains "locked", "serialize" or
    "deadlock". The delay before retry k is ``base_delay * k``.

    NOTE(review): on PostgreSQL a failed statement may leave the transaction
    aborted, in which case re-running the statement alone cannot succeed --
    confirm how callers use this on that backend.
    """
    return _run_with_retry(lambda: conn.execute(sql, params), retries, base_delay)


def commit_retry(
    conn: DbConnection,
    *,
    retries: int = 6,
    base_delay: float = 0.05,
) -> None:
    """Commit *conn*, backing off briefly and retrying on lock conflicts."""
    _run_with_retry(conn.commit, retries, base_delay)


def is_db_contention_error(exc: BaseException) -> bool:
    """Return True for SQLite "locked" or PostgreSQL lock/serialization errors.

    PostgreSQL errors qualify by SQLSTATE (40001, 40P01, 55P03) or by a
    message containing "deadlock", "serialize" or "lock".
    """
    msg = str(exc).lower()
    if isinstance(exc, sqlite3.OperationalError):
        return "locked" in msg
    if _PSYCOPG_OK and isinstance(exc, PgOperationalError):
        code = getattr(exc, "sqlstate", "") or ""
        if code in ("40001", "40P01", "55P03"):
            return True
        return any(word in msg for word in ("deadlock", "serialize", "lock"))
    return False


def reset_backend_for_tests(backend: str | None = None) -> None:
    """Test helper: close the pool and force (or clear) the cached backend."""
    global _backend
    close_pg_pool()
    with _backend_lock:
        _backend = backend