diff --git a/services/nvidia_sidecar/config.py b/services/nvidia_sidecar/config.py index 3176fb8..aa82663 100644 --- a/services/nvidia_sidecar/config.py +++ b/services/nvidia_sidecar/config.py @@ -56,7 +56,7 @@ class SidecarConfig: # ---- 超时 ---- request_timeout: float = field( - default=6000.0, + default=60.0, metadata={"env": "SIDECAR_TIMEOUT"}, ) @@ -153,9 +153,14 @@ def _validate_config(config: SidecarConfig) -> list[str]: # request_timeout 合理性 if config.request_timeout <= 0: issues.append( - f"request_timeout ({config.request_timeout}) 无效,回退到默认值 6000" + f"request_timeout ({config.request_timeout}) 无效,回退到默认值 60" ) - config.request_timeout = 6000.0 + config.request_timeout = 60.0 + elif config.request_timeout > 300.0: + issues.append( + f"request_timeout ({config.request_timeout}) 异常偏高,已截断为 300" + ) + config.request_timeout = 300.0 return issues diff --git a/services/nvidia_sidecar/priority_queue.py b/services/nvidia_sidecar/priority_queue.py index 3db1d05..62a1952 100644 --- a/services/nvidia_sidecar/priority_queue.py +++ b/services/nvidia_sidecar/priority_queue.py @@ -107,6 +107,33 @@ class PriorityRequestQueue: """当前队列满策略。""" return self._full_policy + # ---- 动态容量调整 ---- + + def set_max_size(self, new_size: int) -> tuple[bool, str]: + """动态调整队列最大容量(热重载)。 + + 缩小操作受保护:如果 new_size 小于当前排队数,拒绝变更并 + 提示当前队列深度。 + + Args: + new_size: 新的最大容量。 + + Returns: + (成功标志, 消息)。成功时标志为 True,消息含新旧容量对比; + 失败时标志为 False,消息含拒绝原因和当前深度。 + + Raises: + ValueError: new_size <= 0。 + """ + if new_size <= 0: + raise ValueError(f"max_size 必须为正整数,当前值: {new_size}") + current = len(self._heap) + if new_size < current: + return (False, f"拒绝缩小:新上限 {new_size} < 当前排队数 {current},需要先排空或提升上限") + old = self.max_size + self.max_size = new_size + return (True, f"队列上限已调整:{old} → {new_size}{'(当前排队 ' + str(current) + ')' if current > 0 else ''}") + # ---- 入队 ---- async def put( diff --git a/services/nvidia_sidecar/server.py b/services/nvidia_sidecar/server.py index 418c370..45bb539 100644 --- a/services/nvidia_sidecar/server.py +++ b/services/nvidia_sidecar/server.py @@ -20,6 +20,7 @@ import httpx import structlog import uvicorn from fastapi import FastAPI, Request, Response +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from nvidia_sidecar.config import load_config, SidecarConfig @@ -76,7 +77,7 @@ _pending_requests: dict[str, tuple[asyncio.Future[httpx.Response], float]] """request_id → (response future, enqueued_at) 的映射。""" _metrics_task: asyncio.Task[None] | None = None -# 统计计数器 +# 统计计数器(受 _stats_lock 保护, 修复梁思筑评审 #1: data race) _stats: dict[str, int] = { "total_requests": 0, "nvidia_requests": 0, @@ -86,13 +87,20 @@ _stats: dict[str, int] = { "upstream_errors": 0, "start_time": 0, } +_stats_lock: asyncio.Lock = asyncio.Lock() # --------------------------------------------------------------------------- # 工具函数 # --------------------------------------------------------------------------- -def _extract_model(body: dict[str, Any]) -> str | None: +async def _increment_stat(key: str, delta: int = 1) -> None: + """线程安全的 _stats 计数器自增(梁思筑评审 #1 修复:消除 data race)。""" + async with _stats_lock: + _stats[key] = _stats.get(key, 0) + delta + + +def _extract_model(body: Any) -> str | None: """从请求体中提取模型标识符(兼容 OpenAI Chat/Completions 格式)。 Args: @@ -213,7 +221,7 @@ async def _worker_loop() -> None: ) if not got_token: log.info("low_priority_timeout", request_id=request_id) - _stats["ratelimited_requests"] += 1 + await _increment_stat("ratelimited_requests") _prometheus.record_request(queue_item.priority.name, "ratelimited") if not future.done(): future.set_exception( @@ -241,7 +249,7 @@ async def _worker_loop() -> None: priority=queue_item.priority.name, timeout=_config.request_timeout, ) - _stats["ratelimited_requests"] += 1 + await _increment_stat("ratelimited_requests") _prometheus.record_request(queue_item.priority.name, "ratelimited") if not future.done(): future.set_exception( @@ -309,7 +317,7 @@ async def _worker_loop() -> None: except (httpx.HTTPError, OSError) as exc: log.error("upstream_request_failed", request_id=request_id, error=str(exc)) - _stats["upstream_errors"] += 1 + await _increment_stat("upstream_errors") _prometheus.record_request(queue_item.priority.name, "error") _prometheus.set_health(False) if not future.done(): @@ -347,7 +355,7 @@ async def _passthrough_with_rate_limit( Returns: FastAPI Response。 """ - _stats["passthrough_requests"] += 1 + await _increment_stat("passthrough_requests") _prometheus.increment_fallback() # 低优先级走令牌桶等待 @@ -358,7 +366,7 @@ async def _passthrough_with_rate_limit( timeout=_config.low_priority_timeout, ) if not got_token: - _stats["ratelimited_requests"] += 1 + await _increment_stat("ratelimited_requests") _prometheus.record_request(priority.name, "ratelimited") return JSONResponse( status_code=429, @@ -372,19 +380,20 @@ async def _passthrough_with_rate_limit( else: got_token = await asyncio.to_thread(_token_bucket.consume, tokens=1) if not got_token: - # 非低优先级轮询等待 - deadline = time.monotonic() + 30.0 + # 非低优先级轮询等待,使用 config.request_timeout 替代硬编码 30s + # (严维序评审 minor / 梁思筑评审 #3:hot-reload 假生效修复) + deadline = time.monotonic() + _config.request_timeout while not got_token: await asyncio.sleep(0.1) got_token = await asyncio.to_thread(_token_bucket.consume, tokens=1) if time.monotonic() > deadline: - _stats["ratelimited_requests"] += 1 + await _increment_stat("ratelimited_requests") _prometheus.record_request(priority.name, "ratelimited") return JSONResponse( status_code=429, content={ "error": { - "message": "令牌不足(队列满 + passthrough),等待超时 30s", + "message": f"令牌不足(队列满 + passthrough),等待超时 {_config.request_timeout:.0f}s", "type": "RateLimitedError", } }, @@ -464,6 +473,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, Any]: _http_client = httpx.AsyncClient( timeout=httpx.Timeout(_config.request_timeout), + limits=httpx.Limits( + max_connections=100, + max_keepalive_connections=20, + ), ) _priority_queue = PriorityRequestQueue(max_size=_config.queue_max_size) _token_bucket = AdaptiveTokenBucket( @@ -489,9 +502,25 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, Any]: metrics_server = uvicorn.Server(metrics_config) _metrics_task = asyncio.create_task(metrics_server.serve()) + # CORS 中间件(严维序评审 #8) + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=False, + allow_methods=["*"], + allow_headers=["*"], + ) + # 挂载 webui 子路由 app.include_router(webui_router) + # upstream_api_key 启动检查(严维序评审 #5) + if not _config.upstream_api_key: + logger.warning( + "upstream_api_key_empty", + message="SIDECAR_API_KEY 未设置,NVIDIA 请求将因 401 认证失败", + ) + logger.info( "sidecar_started", host=_config.listen_host, @@ -541,7 +570,7 @@ async def _handle_proxy_request(request: Request, path: str) -> Response: 2. 网关识别 → 非 NVIDIA 直通 3. NVIDIA → 排队 + 令牌限流 + 转发 """ - _stats["total_requests"] += 1 + await _increment_stat("total_requests") # 解析请求 body_bytes: bytes = await request.body() @@ -561,7 +590,7 @@ async def _handle_proxy_request(request: Request, path: str) -> Response: # 非 NVIDIA → 直接转发 if not is_nvidia: - _stats["passthrough_requests"] += 1 + await _increment_stat("passthrough_requests") try: resp = await _forward_to_upstream( method=request.method, @@ -580,7 +609,7 @@ async def _handle_proxy_request(request: Request, path: str) -> Response: ) # NVIDIA → 排队 + 限流 + 转发 - _stats["nvidia_requests"] += 1 + await _increment_stat("nvidia_requests") priority: Priority = _resolve_priority(raw_headers) # 注入内部元数据到 payload @@ -599,7 +628,7 @@ async def _handle_proxy_request(request: Request, path: str) -> Response: }, ) except QueueFullError: - _stats["queue_full_rejects"] += 1 + await _increment_stat("queue_full_rejects") return JSONResponse( status_code=503, content={ @@ -611,7 +640,7 @@ async def _handle_proxy_request(request: Request, path: str) -> Response: ) except QueueFullPassthrough: # 队列满 + PASSTHROUGH:绕过排队,尝试令牌桶后直接转发 - _stats["passthrough_requests"] += 1 + await _increment_stat("passthrough_requests") logger.info("queue_full_passthrough", path=path) return await _passthrough_with_rate_limit(request, path, body_bytes, raw_headers, priority) diff --git a/services/nvidia_sidecar/webui.py b/services/nvidia_sidecar/webui.py index 9167e6d..0c2b33e 100644 --- a/services/nvidia_sidecar/webui.py +++ b/services/nvidia_sidecar/webui.py @@ -8,13 +8,15 @@ from __future__ import annotations import asyncio import json +import os import time from pathlib import Path from typing import Any, AsyncGenerator import structlog -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel webui_router: APIRouter = APIRouter(prefix="/api", tags=["webui"]) @@ -22,6 +24,14 @@ logger: structlog.stdlib.BoundLogger = structlog.get_logger("nvidia_sidecar.webu STATIC_DIR: Path = Path(__file__).parent / "static" +# dashboard.html 缓存(严维序评审 #6 / 梁思筑评审 #8:避免每次请求读磁盘) +_dashboard_html_cache: tuple[str, float] | None = None +_DASHBOARD_CACHE_TTL: float = 300.0 # 5 分钟 + +# Admin API 认证(严维序评审 #1) +_ADMIN_TOKEN: str | None = os.environ.get("SIDECAR_ADMIN_TOKEN") +_admin_auth_scheme: HTTPBearer = HTTPBearer(auto_error=False) + # --------------------------------------------------------------------------- # 配置热重载模型 @@ -44,12 +54,18 @@ async def _dashboard_stream(request: Request) -> StreamingResponse: 供 dashboard.html 的 EventSource 消费。 """ async def event_generator() -> AsyncGenerator[str, None]: + # 首帧发送 retry 字段(严维序评审 minor):指示客户端断连后等待 3s 重试 + first_frame = True while True: if await request.is_disconnected(): break try: snapshot: dict[str, Any] = await _build_snapshot() - yield f"data: {json.dumps(snapshot, ensure_ascii=False)}\n\n" + payload_sse = f"data: {json.dumps(snapshot, ensure_ascii=False)}\n\n" + if first_frame: + payload_sse = f"retry: 3000\n{payload_sse}" + first_frame = False + yield payload_sse except Exception: logger.exception("dashboard_sse_error") yield f"data: {json.dumps({'error': 'internal'})}\n\n" @@ -65,6 +81,12 @@ async def _dashboard_stream(request: Request) -> StreamingResponse: ) +# SSE 首帧写入 retry 字段(严维序评审 minor),在 event_generator 首次 yield 前注入 +# 通过在 StreamingResponse 返回前手动发送 retry header 实现 +# (SSE 协议支持 retry 字段作为重建连接间隔) +# 注:在 event_generator 的首个 yield 中加入 retry 声明 + + async def _build_snapshot() -> dict[str, Any]: """构建当前状态快照(从全局状态读取,含队列深度)。""" # 延迟导入避免循环依赖 @@ -90,7 +112,7 @@ async def _build_snapshot() -> dict[str, Any]: "total_dropped": queue_stats.get("total_dropped", 0), } except Exception: - pass + logger.warning("queue_stats_unavailable", message="队列统计获取失败,仪表盘队列深度可能不准确") return { "timestamp": now, @@ -133,6 +155,7 @@ async def get_config() -> dict[str, Any]: "listen_port": cfg.listen_port, "metrics_port": cfg.metrics_port, "upstream_url": cfg.upstream_url, + "upstream_api_key": _mask_api_key(cfg.upstream_api_key), "rate_rpm": _get_current_rate(server), "bucket_capacity": cfg.bucket_capacity, "request_timeout": cfg.request_timeout, @@ -160,8 +183,12 @@ async def update_config(body: ConfigPatch) -> JSONResponse: if body.queue_max_size is not None: if body.queue_max_size <= 0: raise HTTPException(status_code=400, detail="queue_max_size must be > 0") + ok, msg = server._priority_queue.set_max_size(body.queue_max_size) + if not ok: + raise HTTPException(status_code=400, detail=msg) cfg.queue_max_size = body.queue_max_size changed.append("queue_max_size") + logger.info("queue_max_size_updated", detail=msg) if body.fallback_enabled_passthrough is not None: cfg.fallback_enabled_passthrough = body.fallback_enabled_passthrough @@ -173,6 +200,18 @@ async def update_config(body: ConfigPatch) -> JSONResponse: ) +def _mask_api_key(key: str) -> str: + """对 API Key 进行脱敏处理,仅保留前 4 位以供识别。 + + 严维序评审 #2 / 沈路明评审 #3:防止 API Key 明文泄露。 + """ + if not key: + return "" + if len(key) <= 4: + return key[:2] + "****" + return key[:4] + "****" + + def _get_current_rate(server_module: Any) -> float: """获取当前实际速率(避退调整后),兼容 AdaptiveTokenBucket。""" tb = server_module._token_bucket @@ -191,15 +230,36 @@ async def dashboard_stream(request: Request) -> StreamingResponse: return await _dashboard_stream(request) +async def _verify_admin_auth( + credentials: HTTPAuthorizationCredentials | None = Depends(_admin_auth_scheme), +) -> None: + """Admin API Bearer Token 认证(严维序评审 #1)。 + + 若设置了 SIDECAR_ADMIN_TOKEN 环境变量,则要求请求携带匹配的 Bearer Token。 + 未设置时跳过认证(开发/测试环境)。 + """ + if _ADMIN_TOKEN is None: + return # 未配置认证 token,允许无认证访问 + if credentials is None: + raise HTTPException(status_code=401, detail="需要 Bearer Token 认证(Admin API)") + if credentials.credentials != _ADMIN_TOKEN: + raise HTTPException(status_code=403, detail="Admin Token 无效") + + @webui_router.get("/admin/config") -async def admin_get_config() -> JSONResponse: - """获取当前配置。""" +async def admin_get_config( + _auth: None = Depends(_verify_admin_auth), +) -> JSONResponse: + """获取当前配置(需要 Admin 认证)。""" return JSONResponse(content=await get_config()) @webui_router.post("/admin/config") -async def admin_update_config(body: ConfigPatch) -> JSONResponse: - """在线修改配置(热重载)。""" +async def admin_update_config( + body: ConfigPatch, + _auth: None = Depends(_verify_admin_auth), +) -> JSONResponse: + """在线修改配置(热重载,需要 Admin 认证)。""" return await update_config(body) @@ -207,10 +267,27 @@ async def admin_update_config(body: ConfigPatch) -> JSONResponse: # 仪表盘静态页面 # --------------------------------------------------------------------------- -@webui_router.get("/dashboard", include_in_schema=False) -async def dashboard_page() -> HTMLResponse: - """仪表盘 HTML 页面。""" +def _get_dashboard_html() -> str: + """获取仪表盘 HTML(带缓存,严维序评审 #6 / 梁思筑评审 #8)。 + + 首次加载后缓存 5 分钟,避免每次请求读磁盘。 + """ + global _dashboard_html_cache + now = time.monotonic() + if _dashboard_html_cache is not None: + cached_content, cached_at = _dashboard_html_cache + if now - cached_at < _DASHBOARD_CACHE_TTL: + return cached_content + dashboard_path = STATIC_DIR / "dashboard.html" if dashboard_path.is_file(): - return HTMLResponse(content=dashboard_path.read_text(encoding="utf-8")) - return HTMLResponse(content="

dashboard.html not found

", status_code=404) \ No newline at end of file + content = dashboard_path.read_text(encoding="utf-8") + _dashboard_html_cache = (content, now) + return content + return "

dashboard.html not found

" + + +@webui_router.get("/dashboard", include_in_schema=False) +async def dashboard_page() -> HTMLResponse: + """仪表盘 HTML 页面(含缓存策略)。""" + return HTMLResponse(content=_get_dashboard_html()) \ No newline at end of file