修复bug
This commit is contained in:
@@ -27,7 +27,7 @@ import threading
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple
|
||||
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Tuple, TypeVar
|
||||
|
||||
try:
|
||||
from typing import Annotated # Python 3.9+
|
||||
@@ -35,6 +35,16 @@ except ImportError: # Python 3.8
|
||||
from typing_extensions import Annotated
|
||||
|
||||
import httpx
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
async def run_in_thread(func: Callable[..., _T], /, *args: Any) -> _T:
|
||||
"""在线程池执行阻塞函数(兼容 Python 3.8,无 asyncio.to_thread)。"""
|
||||
if hasattr(asyncio, "to_thread"):
|
||||
return await run_in_thread(func, *args) # type: ignore[attr-defined]
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, func, *args)
|
||||
from fastapi import Depends, FastAPI, HTTPException, Query, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
|
||||
@@ -634,7 +644,7 @@ async def record_api_log(
|
||||
usage: Dict[str, int],
|
||||
node_id: Optional[str] = None,
|
||||
) -> None:
|
||||
await asyncio.to_thread(
|
||||
await run_in_thread(
|
||||
_record_api_log_sync,
|
||||
client_ip,
|
||||
model,
|
||||
@@ -1541,14 +1551,14 @@ async def api_me(
|
||||
|
||||
@app.get("/api/models/cards")
|
||||
async def api_model_cards() -> List[Dict[str, Any]]:
|
||||
return await asyncio.to_thread(build_model_cards)
|
||||
return await run_in_thread(build_model_cards)
|
||||
|
||||
|
||||
@app.get("/api/nodes")
|
||||
async def api_nodes(
|
||||
_: Annotated[GateSessionUser, Depends(get_current_web_user)],
|
||||
) -> List[Dict[str, Any]]:
|
||||
return await asyncio.to_thread(list_nodes_with_status)
|
||||
return await run_in_thread(list_nodes_with_status)
|
||||
|
||||
|
||||
@app.post("/api/nodes")
|
||||
@@ -1566,7 +1576,7 @@ async def api_nodes_create(
|
||||
"models": [{"id": m.id.strip(), "label": (m.label or m.id).strip()} for m in body.models],
|
||||
}
|
||||
nodes = list(_NODES.get("nodes", [])) + [node]
|
||||
await asyncio.to_thread(save_nodes_config, {"nodes": nodes})
|
||||
await run_in_thread(save_nodes_config, {"nodes": nodes})
|
||||
return node
|
||||
|
||||
|
||||
@@ -1589,7 +1599,7 @@ async def api_nodes_update(
|
||||
"max_concurrent": body.max_concurrent,
|
||||
"models": [{"id": m.id.strip(), "label": (m.label or m.id).strip()} for m in body.models],
|
||||
}
|
||||
await asyncio.to_thread(save_nodes_config, {"nodes": nodes})
|
||||
await run_in_thread(save_nodes_config, {"nodes": nodes})
|
||||
return nodes[idx]
|
||||
|
||||
|
||||
@@ -1601,7 +1611,7 @@ async def api_nodes_delete(
|
||||
nodes = [n for n in _NODES.get("nodes", []) if str(n.get("id")) != node_id]
|
||||
if len(nodes) == len(_NODES.get("nodes", [])):
|
||||
raise HTTPException(status_code=404, detail="节点不存在")
|
||||
await asyncio.to_thread(save_nodes_config, {"nodes": nodes})
|
||||
await run_in_thread(save_nodes_config, {"nodes": nodes})
|
||||
return JSONResponse({"ok": True})
|
||||
|
||||
|
||||
@@ -1628,14 +1638,14 @@ async def api_nodes_test(
|
||||
async def api_stats_summary(
|
||||
_: Annotated[GateSessionUser, Depends(get_current_web_user)],
|
||||
) -> Dict[str, Any]:
|
||||
return await asyncio.to_thread(_query_stats_summary)
|
||||
return await run_in_thread(_query_stats_summary)
|
||||
|
||||
|
||||
@app.get("/api/stats/ips")
|
||||
async def api_stats_ips(
|
||||
_: Annotated[GateSessionUser, Depends(get_current_web_user)],
|
||||
) -> List[Dict[str, Any]]:
|
||||
return await asyncio.to_thread(_query_stats_ips)
|
||||
return await run_in_thread(_query_stats_ips)
|
||||
|
||||
|
||||
@app.get("/api/stats/billing")
|
||||
@@ -1643,7 +1653,7 @@ async def api_stats_billing(
|
||||
_: Annotated[GateSessionUser, Depends(get_current_web_user)],
|
||||
days: int = Query(30, ge=1, le=365),
|
||||
) -> Dict[str, Any]:
|
||||
return await asyncio.to_thread(_query_stats_billing, days)
|
||||
return await run_in_thread(_query_stats_billing, days)
|
||||
|
||||
|
||||
@app.get("/api/stats/logs")
|
||||
@@ -1652,7 +1662,7 @@ async def api_stats_logs(
|
||||
limit: int = Query(50, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
) -> Dict[str, Any]:
|
||||
items, total = await asyncio.to_thread(_query_stats_logs, limit, offset)
|
||||
items, total = await run_in_thread(_query_stats_logs, limit, offset)
|
||||
return {"items": items, "total": total, "limit": limit, "offset": offset}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user