8ebe1a3c77
Co-authored-by: Cursor <cursoragent@cursor.com>
229 lines
6.0 KiB
Python
229 lines
6.0 KiB
Python
# Copyright (c) 2025-2026 马建军. All rights reserved.
|
|
# 专有软件 — 未经授权禁止复制、传播、转售。
|
|
# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。
|
|
# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md
|
|
|
|
"""数据库连接:SQLite futures.db。"""
|
|
from __future__ import annotations

import sqlite3
import time
from types import TracebackType
from typing import Any, Iterable, Iterator, Sequence

from modules.core.paths import DB_PATH as _ROOT_DB_PATH, KLINE_DB_PATH as _KLINE_DB_PATH
|
|
|
|
# Public database paths, re-exported from modules.core.paths so callers only
# need to import this module. DB_PATH is the default target of connect_db();
# presumably the futures.db named in the module docstring — confirm in paths.
DB_PATH, KLINE_DB_PATH = _ROOT_DB_PATH, _KLINE_DB_PATH

# sqlite3 exception types re-exported under backend-neutral names, so callers
# can write `except OperationalError` without importing sqlite3 themselves.
OperationalError, IntegrityError = sqlite3.OperationalError, sqlite3.IntegrityError
|
|
|
|
|
|
def db_backend() -> str:
    """Return the name of the active database backend.

    SQLite is the only backend this module drives, so the answer is fixed.
    """
    backend_name = "sqlite"
    return backend_name
|
|
|
|
|
|
def is_postgres() -> bool:
    """Report whether PostgreSQL is the active backend.

    Always False: this module only opens SQLite connections.
    """
    postgres_active = False
    return postgres_active
|
|
|
|
|
|
def database_label() -> str:
    """Return a human-readable label naming the backend and the default database path."""
    backend_name = "SQLite"
    return f"{backend_name} ({DB_PATH})"
|
|
|
|
|
|
def adapt_sql(sql: str) -> str:
    """Return *sql* unchanged.

    With SQLite as the only backend no dialect translation is needed; the
    function is kept so call sites that route SQL through it keep working.
    """
    adapted = sql
    return adapted
|
|
|
|
|
|
def is_benign_migration_error(exc: BaseException) -> bool:
    """Return True for initialization-migration errors that may be ignored.

    Example: ALTER TABLE adding a column that already exists. Delegates
    entirely to ``is_schema_migration_error``.
    """
    benign = is_schema_migration_error(exc)
    return benign
|
|
|
|
|
|
# Lower-case message fragments that mark an ignorable incremental-migration
# error. "duplicate column" / "no such table" / "already exists" are SQLite
# wordings; "duplicate key" / "does not exist" are not SQLite's and are
# presumably kept from the former PostgreSQL backend — NOTE(review): confirm
# they are still wanted.
_SCHEMA_MIGRATION_ERROR_MARKERS = (
    "duplicate column",
    "already exists",
    "duplicate key",
    "no such table",
    "does not exist",
)


def is_schema_migration_error(exc: BaseException) -> bool:
    """Return True if *exc* is an init_db incremental-migration error that can be ignored.

    Covers missing tables, missing relations and duplicate columns/objects.
    Matching is a case-insensitive substring test on ``str(exc)`` and does NOT
    look at the exception type, so any exception whose message contains one of
    the markers qualifies.

    Args:
        exc: the exception raised by a migration statement.

    Returns:
        True if the message contains any of ``_SCHEMA_MIGRATION_ERROR_MARKERS``.
    """
    msg = str(exc).lower()
    # A single pass over the markers is sufficient: the former extra
    # sqlite3.OperationalError check tested a subset of these same markers and
    # could never change the result.
    return any(marker in msg for marker in _SCHEMA_MIGRATION_ERROR_MARKERS)
|
|
|
|
|
|
def is_missing_relation_error(exc: BaseException) -> bool:
    """Return True if *exc* reports that a table or view does not exist.

    Case-insensitive substring test on ``str(exc)`` for "no such table"
    (SQLite's wording) or "does not exist"; the exception type is not checked.

    Args:
        exc: the exception to classify.

    Returns:
        True when either marker appears in the message.
    """
    # Both markers are part of the schema-migration marker set, so a separate
    # "is this a migration error?" pre-check would never change the answer.
    msg = str(exc).lower()
    return "no such table" in msg or "does not exist" in msg
|
|
|
|
|
|
def rollback_if_postgres(conn: "DbConnection") -> None:
    """No-op kept for backward compatibility with older callers.

    Only SQLite is used now, so there is never a PostgreSQL transaction to
    roll back; *conn* is accepted and ignored.
    """
    return None
|
|
|
|
|
|
class DbCursor:
    """Cursor wrapper exposing the parts of ``sqlite3.Cursor`` this project uses.

    After every ``execute``/``executemany`` the wrapper copies ``rowcount`` and
    ``lastrowid`` from the underlying cursor, so they can be read as plain
    attributes. Iteration, ``fetchmany`` and ``description`` are forwarded too,
    so code written against ``sqlite3.Cursor`` (e.g.
    ``for row in conn.execute(...)``) also works with this wrapper.
    """

    def __init__(self, raw_cursor: Any, raw_conn: Any) -> None:
        self._cur = raw_cursor
        # Owning raw connection; stored for reference, not used by the methods below.
        self._conn = raw_conn
        self.lastrowid: int | None = None
        self.rowcount: int = 0

    def _sync_counters(self) -> None:
        """Copy rowcount/lastrowid from the raw cursor (missing/None rowcount becomes 0)."""
        self.rowcount = int(getattr(self._cur, "rowcount", 0) or 0)
        self.lastrowid = getattr(self._cur, "lastrowid", None)

    def execute(self, sql: str, params: Sequence[Any] | None = None) -> "DbCursor":
        """Run one statement with optional bind *params*; returns ``self`` for chaining."""
        self._cur.execute(sql, params or ())
        self._sync_counters()
        return self

    def executemany(self, sql: str, seq_of_params: Iterable[Sequence[Any]]) -> "DbCursor":
        """Run *sql* once per parameter set; returns ``self`` for chaining.

        ``rowcount`` afterwards is whatever the raw cursor reports (sqlite3
        sums the modifications of all executions).
        """
        self._cur.executemany(sql, seq_of_params)
        self._sync_counters()
        return self

    @property
    def description(self) -> Any:
        """Column metadata of the last query as reported by the raw cursor (None if it has none)."""
        return getattr(self._cur, "description", None)

    def fetchone(self) -> Any:
        """Return the next row from the raw cursor (sqlite3 returns None when exhausted)."""
        return self._cur.fetchone()

    def fetchmany(self, size: int | None = None) -> list[Any]:
        """Return up to *size* rows; without *size* the raw cursor's default batch size applies."""
        if size is None:
            return self._cur.fetchmany()
        return self._cur.fetchmany(size)

    def fetchall(self) -> list[Any]:
        """Return all remaining rows of the current result set."""
        return self._cur.fetchall()

    def __iter__(self) -> Iterator[Any]:
        """Iterate over the remaining rows, like iterating a ``sqlite3.Cursor``."""
        return iter(self._cur)

    def close(self) -> None:
        """Close the raw cursor; failures are deliberately ignored (best-effort cleanup)."""
        try:
            self._cur.close()
        except Exception:
            pass
|
|
|
|
|
|
class DbConnection:
    """Connection wrapper with ``sqlite3.Connection``-style execute / commit / rollback / close.

    Usable as a context manager. On a clean exit the transaction is committed;
    if the body raised, it is rolled back. The connection is closed either way.
    A failing commit is propagated rather than swallowed, so lost writes are
    never silent.
    """

    def __init__(self, raw_conn: Any) -> None:
        self._conn = raw_conn

    @property
    def row_factory(self) -> Any:
        """Row factory of the underlying connection (``connect_db`` sets ``sqlite3.Row``)."""
        return getattr(self._conn, "row_factory", None)

    @row_factory.setter
    def row_factory(self, factory: Any) -> None:
        # Forward to the raw connection. sqlite3 cursors copy the connection's
        # row_factory when they are created, so this affects every later
        # execute()/cursor() call. A plain attribute on the wrapper would be
        # silently ignored.
        self._conn.row_factory = factory

    def execute(self, sql: str, params: Sequence[Any] | None = None) -> DbCursor:
        """Run *sql* on a fresh cursor and return it (wrapped as DbCursor)."""
        cur = self.cursor()
        try:
            return cur.execute(sql, params)
        except Exception:
            # Don't leave the half-used cursor open when the statement fails.
            cur.close()
            raise

    def cursor(self) -> DbCursor:
        """Create a new DbCursor over a raw cursor of this connection."""
        return DbCursor(self._conn.cursor(), self._conn)

    def commit(self) -> None:
        """Commit the current transaction."""
        self._conn.commit()

    def rollback(self) -> None:
        """Roll back the current transaction."""
        self._conn.rollback()

    def close(self) -> None:
        """Close the raw connection; failures are deliberately ignored (best-effort cleanup)."""
        try:
            self._conn.close()
        except Exception:
            pass

    def __enter__(self) -> "DbConnection":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        """Commit on success, roll back on error, and always close.

        Commit errors propagate to the caller. A rollback error is suppressed
        so that the body's own exception is the one that propagates. Returning
        None never suppresses the body's exception.
        """
        try:
            if exc_type is None:
                self.commit()
            else:
                try:
                    self.rollback()
                except Exception:
                    pass
        finally:
            self.close()
|
|
|
|
|
|
def connect_db(path: str | None = None) -> DbConnection:
    """Open a SQLite connection and return it wrapped in a DbConnection.

    The caller must close it when done (or use it as a context manager).

    Args:
        path: database file to open; a falsy value falls back to ``DB_PATH``.

    Returns:
        A DbConnection whose rows are ``sqlite3.Row``. It waits up to 30 s on a
        locked database (both the ``timeout`` argument and ``PRAGMA
        busy_timeout``) and may be used from threads other than the creator
        (``check_same_thread=False``).

    Raises:
        Any error from ``sqlite3.connect`` or from configuring the new
        connection. In the configuration case the raw connection is closed
        before the error propagates, so it does not leak.
    """
    db_path = path or DB_PATH
    raw = sqlite3.connect(db_path, timeout=30, check_same_thread=False)
    try:
        raw.row_factory = sqlite3.Row
        raw.execute("PRAGMA busy_timeout=30000")
        try:
            raw.execute("PRAGMA journal_mode=WAL")
        except sqlite3.OperationalError:
            # WAL is optional: if switching fails, the connection keeps its
            # current journal mode and remains usable.
            pass
        return DbConnection(raw)
    except BaseException:
        # e.g. DatabaseError ("file is not a database") from the PRAGMAs.
        raw.close()
        raise
|
|
|
|
|
|
def close_pg_pool() -> None:
    """No-op kept for backward compatibility with older callers.

    There is no PostgreSQL connection pool any more, so there is nothing to close.
    """
    return None
|
|
|
|
|
|
def execute_retry(
    conn: DbConnection,
    sql: str,
    params: Sequence[Any] = (),
    *,
    retries: int = 6,
    base_delay: float = 0.05,
) -> DbCursor:
    """Execute *sql* on *conn*, retrying with linear back-off while SQLite reports a lock.

    Args:
        conn: connection to run the statement on (anything with ``execute(sql, params)``).
        sql: statement text.
        params: bind parameters passed through to ``conn.execute``.
        retries: total number of attempts. Values below 1 are treated as 1, so
            the statement always runs at least once.
        base_delay: seconds slept after the first lock failure; the n-th retry
            waits ``base_delay * n``.

    Returns:
        Whatever ``conn.execute`` returns (a DbCursor for a DbConnection).

    Raises:
        sqlite3.OperationalError: immediately for errors whose message does
            not contain "locked", or the last lock error once every attempt
            has failed.
    """
    attempts = max(1, retries)
    for attempt in range(1, attempts + 1):
        try:
            return conn.execute(sql, params)
        except sqlite3.OperationalError as exc:
            # Only lock contention is transient; anything else is a real error.
            # On the final attempt the lock error itself is re-raised.
            if "locked" not in str(exc).lower() or attempt == attempts:
                raise
        time.sleep(base_delay * attempt)
    # Unreachable: the final attempt either returns or re-raises.
    raise sqlite3.OperationalError("database is locked")
|
|
|
|
|
|
def commit_retry(
    conn: DbConnection,
    *,
    retries: int = 6,
    base_delay: float = 0.05,
) -> None:
    """Commit *conn*, retrying with linear back-off while SQLite reports a lock.

    Args:
        conn: connection to commit (anything with ``commit()``).
        retries: total number of attempts. Values below 1 are treated as 1, so
            the commit is always attempted at least once.
        base_delay: seconds slept after the first lock failure; the n-th retry
            waits ``base_delay * n``.

    Raises:
        sqlite3.OperationalError: immediately for errors whose message does
            not contain "locked", or the last lock error once every attempt
            has failed.
    """
    attempts = max(1, retries)
    for attempt in range(1, attempts + 1):
        try:
            conn.commit()
            return
        except sqlite3.OperationalError as exc:
            # Only lock contention is transient; anything else is a real error.
            # On the final attempt the lock error itself is re-raised.
            if "locked" not in str(exc).lower() or attempt == attempts:
                raise
        time.sleep(base_delay * attempt)
    # Unreachable: the final attempt either returns or re-raises.
    raise sqlite3.OperationalError("database is locked")
|
|
|
|
|
|
def is_db_contention_error(exc: BaseException) -> bool:
    """Return True if *exc* is a SQLite lock-contention error.

    Only ``sqlite3.OperationalError`` instances whose message contains
    "locked" (case-insensitive) qualify.
    """
    if not isinstance(exc, sqlite3.OperationalError):
        return False
    text = str(exc).lower()
    return "locked" in text
|
|
|
|
|
|
def reset_backend_for_tests(backend: str | None = None) -> None:
    """No-op kept so older tests that call it keep working.

    The backend is always SQLite, so there is nothing to reset; *backend* is
    accepted and ignored.
    """
    return None
|