Files
qihuo/modules/core/db_conn.py
T
dekun e5a586f903 Restructure into modules/ with single-process CTP and config/ layout.
Move business code under modules/, env template to config/, PM2 single qihuo process, and _legacy shims for old imports.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-07-01 14:42:16 +08:00

346 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright (c) 2025-2026 马建军. All rights reserved.
# 专有软件 — 未经授权禁止复制、传播、转售。
# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。
# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md
"""数据库连接:开发默认 SQLite,生产推荐 PostgreSQLDATABASE_URL)。"""
from __future__ import annotations
import os
import re
import sqlite3
import threading
import time
from typing import Any, Iterable, Optional, Sequence
from modules.core.paths import DB_PATH as _ROOT_DB_PATH
DB_PATH = _ROOT_DB_PATH
# Cached backend name (``None`` until first detected) and the lock that
# guards its first-time initialisation.
_backend: Optional[str] = None
_backend_lock = threading.Lock()

# psycopg is an optional dependency. When it cannot be imported, the names
# below are bound to inert placeholders so the rest of the module can still
# reference them, and ``_PSYCOPG_OK`` records that the driver is unavailable.
try:
    import psycopg
    from psycopg import IntegrityError as PgIntegrityError
    from psycopg import OperationalError as PgOperationalError
    from psycopg.rows import dict_row
except ImportError:
    _PSYCOPG_OK = False
    psycopg = None  # type: ignore[assignment]
    dict_row = None  # type: ignore[assignment]
    PgOperationalError = Exception  # type: ignore[misc,assignment]
    PgIntegrityError = Exception  # type: ignore[misc,assignment]
else:
    _PSYCOPG_OK = True

# Module-level aliases for the sqlite3 exception classes.
IntegrityError = sqlite3.IntegrityError
OperationalError = sqlite3.OperationalError
def db_backend() -> str:
    """Return the active backend name: ``sqlite`` or ``postgres``.

    The answer is derived from the ``DATABASE_URL`` environment variable the
    first time it is needed and then cached in the module-level ``_backend``.
    A URL starting with ``postgresql://`` or ``postgres://`` selects
    PostgreSQL; anything else (including an unset variable) selects SQLite.

    Raises:
        RuntimeError: a PostgreSQL URL is configured but psycopg could not be
            imported. Nothing is cached in that case.
    """
    global _backend
    if _backend is None:
        with _backend_lock:
            # Check again while holding the lock: the cache may have been
            # filled between the unlocked test above and acquiring the lock.
            if _backend is None:
                dsn = (os.getenv("DATABASE_URL") or "").strip()
                wants_postgres = dsn.startswith(("postgresql://", "postgres://"))
                if wants_postgres and not _PSYCOPG_OK:
                    raise RuntimeError(
                        "已配置 DATABASE_URL 但未安装 psycopg,请执行: pip install 'psycopg[binary]'"
                    )
                _backend = "postgres" if wants_postgres else "sqlite"
    return _backend
def is_postgres() -> bool:
    """Report whether ``db_backend()`` currently resolves to ``postgres``."""
    backend = db_backend()
    return backend == "postgres"
def database_label() -> str:
    """Return a short human-readable description of the active database.

    For SQLite the label contains ``DB_PATH``. For PostgreSQL it contains the
    part of ``DATABASE_URL`` between the last ``@`` and the following ``/``,
    or the fixed word ``postgresql`` when the URL has no ``@``.
    """
    if not is_postgres():
        return f"SQLite ({DB_PATH})"
    dsn = (os.getenv("DATABASE_URL") or "").strip()
    if "@" in dsn:
        after_credentials = dsn.rsplit("@", 1)[1]
        host = after_credentials.partition("/")[0]
    else:
        host = "postgresql"
    return f"PostgreSQL ({host})"
# Tokens that must be copied through verbatim while converting placeholders:
# line comments, block comments, single-quoted literals and double-quoted
# identifiers (both with doubled-quote escapes). The last alternative is the
# bare ``?`` placeholder itself.
_SQL_PLACEHOLDER_TOKEN_RE = re.compile(
    r"--[^\n]*"
    r"|/\*.*?\*/"
    r"|'(?:[^']|'')*'"
    r"|\"(?:[^\"]|\"\")*\""
    r"|\?",
    re.DOTALL,
)


def _qmark_to_format(sql: str) -> str:
    """Replace ``?`` placeholders with ``%s`` outside quotes and comments."""
    return _SQL_PLACEHOLDER_TOKEN_RE.sub(
        lambda token: "%s" if token.group(0) == "?" else token.group(0),
        sql,
    )


def adapt_sql(sql: str) -> str:
    """Adapt SQLite-flavoured SQL to the active backend.

    On SQLite the statement is returned unchanged. On PostgreSQL:

    * ``INTEGER PRIMARY KEY AUTOINCREMENT`` becomes ``SERIAL PRIMARY KEY``;
    * any remaining ``AUTOINCREMENT`` keyword is removed;
    * ``DEFAULT "text"`` becomes ``DEFAULT 'text'``;
    * ``?`` placeholders become ``%s``.

    Only placeholders in statement position are converted. A question mark
    inside a quoted literal, a quoted identifier or a comment is not a
    placeholder, so rewriting it would both corrupt the text and change the
    number of parameters the driver expects.

    Args:
        sql: Statement written with SQLite syntax and ``?`` placeholders.

    Returns:
        The statement to send to the active backend.
    """
    if not is_postgres():
        return sql
    out = re.sub(
        r"\bINTEGER PRIMARY KEY AUTOINCREMENT\b",
        "SERIAL PRIMARY KEY",
        sql,
        flags=re.IGNORECASE,
    )
    out = re.sub(r"\bAUTOINCREMENT\b", "", out, flags=re.IGNORECASE)
    out = re.sub(r'DEFAULT\s+"([^"]*)"', r"DEFAULT '\1'", out, flags=re.IGNORECASE)
    if "?" in out:
        out = _qmark_to_format(out)
    return out
def is_benign_migration_error(exc: BaseException) -> bool:
    """Tell whether *exc* is an ignorable error from an initialisation migration.

    Delegates entirely to ``is_schema_migration_error``.
    """
    return True if is_schema_migration_error(exc) else False
def is_schema_migration_error(exc: BaseException) -> bool:
    """Tell whether *exc* is a schema error that incremental migrations may ignore.

    Covers missing tables, missing columns and duplicate objects. Detection
    works in two ways:

    * by message: the lower-cased ``str(exc)`` contains one of a fixed list of
      phrases (this is what catches sqlite3 errors and English PostgreSQL
      messages);
    * by SQLSTATE: the exception carries a ``sqlstate`` attribute equal to
      42701, 42P07, 42P01 or 42703. The code is read from any exception type
      rather than only from ``OperationalError``, because psycopg reports
      class-42 errors as ``ProgrammingError`` subclasses, and it is the only
      reliable signal when the server message is not in English.

    Args:
        exc: The exception raised by the failed statement.

    Returns:
        ``True`` when the error matches one of the patterns above.
    """
    msg = str(exc).lower()
    ignorable_phrases = (
        "duplicate column",
        "already exists",
        "duplicate key",
        "no such table",
        "does not exist",
        "undefined table",
        "undefined column",
    )
    if any(phrase in msg for phrase in ignorable_phrases):
        return True
    # Exceptions without a ``sqlstate`` attribute (e.g. sqlite3's) yield "".
    code = getattr(exc, "sqlstate", None) or ""
    return code in ("42701", "42P07", "42P01", "42703")
def is_missing_relation_error(exc: BaseException) -> bool:
    """Tell whether *exc* reports a missing table or view.

    The error must first qualify under ``is_schema_migration_error``; its
    lower-cased message must then contain ``no such table``,
    ``does not exist`` or ``undefined table``.

    NOTE(review): ``does not exist`` is matched as a plain substring, so any
    message containing that phrase qualifies, not only relation errors.
    """
    if not is_schema_migration_error(exc):
        return False
    text = str(exc).lower()
    for needle in ("no such table", "does not exist", "undefined table"):
        if needle in text:
            return True
    return False
def rollback_if_postgres(conn: "DbConnection") -> None:
    """Roll *conn* back when the active backend is PostgreSQL.

    Does nothing on SQLite. Any exception raised by the rollback itself is
    suppressed.
    """
    if not is_postgres():
        return
    try:
        conn.rollback()
    except Exception:
        # Best effort: this runs on error paths where the original failure
        # is the one worth reporting.
        pass
class DbCursor:
    """Uniform cursor exposing sqlite3-style ``execute`` / ``fetchone`` / ``lastrowid``.

    Wraps a raw driver cursor. Statements are passed through ``adapt_sql``
    before execution, and ``rowcount`` / ``lastrowid`` are refreshed after
    every ``execute``. When the raw cursor reports no ``lastrowid`` and the
    backend is PostgreSQL, the id of an ``INSERT`` is derived from the
    statement's ``RETURNING`` row or, failing that, from ``lastval()``.
    """

    # Matches statements that begin with INSERT (leading whitespace allowed).
    _INSERT_RE = re.compile(r"^\s*INSERT\b", re.IGNORECASE)

    def __init__(self, backend: str, raw_cursor: Any, raw_conn: Any) -> None:
        self._backend = backend
        self._cur = raw_cursor
        self._conn = raw_conn
        self.lastrowid: Optional[int] = None
        self.rowcount: int = 0
        # Row read from an ``INSERT ... RETURNING`` result while deriving
        # ``lastrowid``; handed back by the next fetchone()/fetchall() so the
        # caller still sees the complete result set.
        self._pending_row: Any = None

    def execute(self, sql: str, params: Sequence[Any] | None = None) -> "DbCursor":
        """Adapt and run *sql*, then refresh ``rowcount`` and ``lastrowid``.

        Args:
            sql: Statement in SQLite style; converted by ``adapt_sql``.
            params: Positional parameters; ``None`` is sent as an empty tuple.

        Returns:
            This cursor, so calls can be chained.
        """
        sql = adapt_sql(sql)
        self._pending_row = None
        self._cur.execute(sql, params or ())
        self.rowcount = int(getattr(self._cur, "rowcount", 0) or 0)
        self.lastrowid = getattr(self._cur, "lastrowid", None)
        if (
            self.lastrowid is None
            and self._backend == "postgres"
            and self._INSERT_RE.match(sql)
        ):
            self._capture_postgres_insert_id()
        return self

    def _capture_postgres_insert_id(self) -> None:
        """Derive ``lastrowid`` for an INSERT that was just executed."""
        try:
            row = self._cur.fetchone()
            if row is not None:
                # Keep the row: the caller may have written RETURNING in
                # order to fetch it.
                self._pending_row = row
                if isinstance(row, dict):
                    self.lastrowid = int(row.get("id") or row.get("Id") or 0) or None
                else:
                    self.lastrowid = int(row[0])
        except Exception:
            # No result set (INSERT without RETURNING) or a first column that
            # is not an integer: fall back to the session's last sequence value.
            self.lastrowid = self._lastval_in_savepoint()

    def _lastval_in_savepoint(self) -> Optional[int]:
        """Return ``lastval()`` or ``None``, without harming the open transaction.

        ``lastval()`` is a server-side error when no sequence has been used in
        the session, and a server-side error aborts the enclosing PostgreSQL
        transaction. Running the probe inside ``transaction()`` confines a
        failure to a savepoint, at the cost of extra round trips.
        """
        try:
            with self._conn.transaction():
                self._cur.execute("SELECT lastval()")
                lv = self._cur.fetchone()
            if lv:
                return int(lv["lastval"] if isinstance(lv, dict) else lv[0])
        except Exception:
            # Best effort: lastrowid simply stays unknown.
            pass
        return None

    def fetchone(self) -> Any:
        """Return the next row, or ``None`` when the result is exhausted."""
        if self._pending_row is not None:
            row, self._pending_row = self._pending_row, None
            return row
        return self._cur.fetchone()

    def fetchall(self) -> list[Any]:
        """Return all remaining rows."""
        rows = self._cur.fetchall()
        if self._pending_row is not None:
            rows = [self._pending_row, *rows]
            self._pending_row = None
        return rows

    def close(self) -> None:
        """Close the raw cursor, ignoring any error it raises."""
        try:
            self._cur.close()
        except Exception:
            pass
class DbConnection:
    """Uniform connection with a sqlite3.Connection-like surface.

    Offers ``execute`` / ``cursor`` / ``commit`` / ``rollback`` / ``close``
    and works as a context manager that commits on a clean exit, rolls back
    when an exception is propagating, and closes the connection either way.
    """

    def __init__(
        self,
        backend: str,
        raw_conn: Any,
        *,
        from_pool: bool = False,
    ) -> None:
        self._from_pool = from_pool
        self._conn = raw_conn
        self._backend = backend
        self.row_factory = None

    def cursor(self) -> DbCursor:
        """Open a raw cursor and wrap it in a ``DbCursor``."""
        if self._backend != "sqlite":
            # Non-sqlite cursors are asked to return rows as dicts.
            return DbCursor(
                self._backend,
                self._conn.cursor(row_factory=dict_row),
                self._conn,
            )
        return DbCursor(self._backend, self._conn.cursor(), self._conn)

    def execute(self, sql: str, params: Sequence[Any] | None = None) -> DbCursor:
        """Run *sql* on a fresh cursor and return that cursor.

        On failure the connection is rolled back when the backend is
        PostgreSQL, and the original exception is re-raised.
        """
        cursor = self.cursor()
        try:
            result = cursor.execute(sql, params)
        except Exception:
            rollback_if_postgres(self)
            raise
        return result

    def commit(self) -> None:
        """Commit the raw connection."""
        self._conn.commit()

    def rollback(self) -> None:
        """Roll the raw connection back."""
        self._conn.rollback()

    def close(self) -> None:
        """Close the raw connection, ignoring any error it raises."""
        try:
            self._conn.close()
        except Exception:
            pass

    def __enter__(self) -> "DbConnection":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # NOTE(review): a failing commit is suppressed here just like a
        # failing rollback, so callers are not told when a commit is lost.
        finish = self.rollback if exc else self.commit
        try:
            finish()
        except Exception:
            pass
        self.close()
def connect_db(path: str | None = None) -> DbConnection:
    """Open a new database connection; the caller closes it when done.

    Both backends create a fresh connection on every call.

    * PostgreSQL: connects to ``DATABASE_URL`` with dict rows and tries to set
      the session time zone to ``Asia/Shanghai``. If that statement fails the
      connection is rolled back, so it is not handed out inside an aborted
      transaction. *path* is not used.
    * SQLite: opens *path* (``DB_PATH`` when omitted) with a 30 second busy
      timeout, ``sqlite3.Row`` rows and cross-thread use allowed, and tries
      to switch the journal to WAL.

    Args:
        path: SQLite database file; ignored on PostgreSQL.

    Returns:
        A ``DbConnection`` wrapping the raw driver connection.
    """
    if is_postgres():
        url = (os.getenv("DATABASE_URL") or "").strip()
        raw = psycopg.connect(url, row_factory=dict_row)
        try:
            with raw.cursor() as cur:
                cur.execute("SET TIME ZONE 'Asia/Shanghai'")
            raw.commit()
        except Exception:
            # Setting the time zone is best effort, but a failed statement
            # leaves the transaction aborted; clear it before returning.
            try:
                raw.rollback()
            except Exception:
                pass
        return DbConnection("postgres", raw, from_pool=False)
    db_path = path or DB_PATH
    raw = sqlite3.connect(db_path, timeout=30, check_same_thread=False)
    raw.row_factory = sqlite3.Row
    raw.execute("PRAGMA busy_timeout=30000")
    try:
        raw.execute("PRAGMA journal_mode=WAL")
    except sqlite3.OperationalError:
        # The connection still works with the existing journal mode.
        pass
    return DbConnection("sqlite", raw)
def close_pg_pool() -> None:
    """Do nothing; kept so existing callers keep working.

    PostgreSQL connections are opened directly, so there is no global pool
    to close.
    """
    return None
def execute_retry(
    conn: DbConnection,
    sql: str,
    params: tuple = (),
    *,
    retries: int = 6,
    base_delay: float = 0.05,
) -> DbCursor:
    """Execute *sql*, retrying with a growing delay on lock contention.

    An ``OperationalError`` / ``PgOperationalError`` whose lower-cased message
    contains ``locked``, ``serialize`` or ``deadlock`` is retried; any other
    error propagates immediately. The wait before retry *k* is
    ``base_delay * k``; there is no wait after the final attempt.

    Args:
        conn: Connection to execute on.
        sql: Statement to run.
        params: Positional parameters for the statement.
        retries: Maximum number of attempts.
        base_delay: Delay unit in seconds.

    Returns:
        The cursor returned by ``conn.execute``.

    Raises:
        The last contention error once all attempts are used up.
    """
    failure: Exception | None = None
    for attempt_no in range(1, retries + 1):
        try:
            return conn.execute(sql, params)
        except (OperationalError, PgOperationalError) as err:
            text = str(err).lower()
            if not any(marker in text for marker in ("locked", "serialize", "deadlock")):
                raise
            failure = err
            if attempt_no < retries:
                time.sleep(base_delay * attempt_no)
    if failure:
        raise failure
    # Reached only when no attempt was made at all.
    raise OperationalError("database is locked")
def commit_retry(
    conn: DbConnection,
    *,
    retries: int = 6,
    base_delay: float = 0.05,
) -> None:
    """Commit *conn*, retrying with a growing delay on lock contention.

    An ``OperationalError`` / ``PgOperationalError`` whose lower-cased message
    contains ``locked``, ``serialize`` or ``deadlock`` is retried; any other
    error propagates immediately. The wait before retry *k* is
    ``base_delay * k``; there is no wait after the final attempt.

    Args:
        conn: Connection to commit.
        retries: Maximum number of attempts.
        base_delay: Delay unit in seconds.

    Raises:
        The last contention error once all attempts are used up.
    """
    failure: Exception | None = None
    for attempt_no in range(1, retries + 1):
        try:
            conn.commit()
        except (OperationalError, PgOperationalError) as err:
            text = str(err).lower()
            if not any(marker in text for marker in ("locked", "serialize", "deadlock")):
                raise
            failure = err
            if attempt_no < retries:
                time.sleep(base_delay * attempt_no)
        else:
            return
    if failure:
        raise failure
    # Reached only when no attempt was made at all.
    raise OperationalError("database is locked")
def is_db_contention_error(exc: BaseException) -> bool:
    """Tell whether *exc* signals lock contention.

    * ``sqlite3.OperationalError``: the lower-cased message contains ``locked``.
    * ``PgOperationalError`` (only when psycopg was imported): ``sqlstate`` is
      40001, 40P01 or 55P03, or the lower-cased message contains ``deadlock``,
      ``serialize`` or ``lock``.

    Every other exception yields ``False``.
    """
    text = str(exc).lower()
    if isinstance(exc, sqlite3.OperationalError):
        return "locked" in text
    if not (_PSYCOPG_OK and isinstance(exc, PgOperationalError)):
        return False
    sqlstate = getattr(exc, "sqlstate", "") or ""
    if sqlstate in ("40001", "40P01", "55P03"):
        return True
    for word in ("deadlock", "serialize", "lock"):
        if word in text:
            return True
    return False
def reset_backend_for_tests(backend: str | None = None) -> None:
    """Test helper: overwrite the cached backend name.

    Args:
        backend: Value to store in the cache. ``None`` clears it, so the next
            ``db_backend()`` call detects the backend again.
    """
    global _backend
    close_pg_pool()
    _backend_lock.acquire()
    try:
        _backend = backend
    finally:
        _backend_lock.release()