diff --git a/services/nvidia_sidecar/server.py b/services/nvidia_sidecar/server.py index bd932e9..e641c0f 100644 --- a/services/nvidia_sidecar/server.py +++ b/services/nvidia_sidecar/server.py @@ -12,6 +12,8 @@ BIZ-46 Phase3: 架构解耦 — 所有全局状态收敛为 SidecarContext (§1) from __future__ import annotations import asyncio +import json + import logging import time from collections.abc import AsyncGenerator @@ -23,7 +25,7 @@ import structlog import uvicorn from fastapi import Depends, FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse from nvidia_sidecar.config import load_config, SidecarConfig from nvidia_sidecar.context import SidecarContext @@ -59,7 +61,7 @@ structlog.configure( structlog.processors.JSONRenderer(), ], context_class=dict, - logger_factory=structlog.PrintLoggerFactory(), + logger_factory=structlog.stdlib.LoggerFactory(), wrapper_class=structlog.stdlib.BoundLogger, cache_logger_on_first_use=True, ) @@ -70,9 +72,9 @@ logger: structlog.stdlib.BoundLogger = structlog.get_logger("nvidia_sidecar") # FastAPI 依赖注入 # --------------------------------------------------------------------------- -def get_context(request: Request) -> SidecarContext: +def get_context() -> SidecarContext: """从 app.state 获取 SidecarContext(FastAPI 依赖注入)。""" - return request.app.state.sidecar # type: ignore[no-any-return] + return app.state.sidecar # type: ignore[no-any-return] # --------------------------------------------------------------------------- @@ -137,7 +139,12 @@ async def _forward_to_upstream( Raises: httpx.HTTPError: HTTP 请求失败。 """ - upstream_url = ctx.config.upstream_url.rstrip("/") + path + # 构建上游 URL:如果 upstream_url 已经包含 /v1 路径,则避免路径重复 + base_url = ctx.config.upstream_url.rstrip("/") + if base_url.endswith("/v1") and path.startswith("/v1"): + upstream_url = base_url + path[3:] # 去掉 path 中的 /v1 前缀 + else: + upstream_url = base_url + path forward_headers: dict[str, str] = { k: v for k, v in headers.items() if k.lower() not in ("host", "content-length", "transfer-encoding") @@ -489,28 +496,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, Any]: # 启动 worker 协程 worker_task = asyncio.create_task(_worker_loop(ctx)) - # 在独立端口 :9191 启动 Prometheus metrics 服务器 - metrics_app = prometheus.build_asgi_app() - metrics_config = uvicorn.Config( - metrics_app, - host=config.listen_host, - port=config.metrics_port, - log_level="error", - ) - metrics_server = uvicorn.Server(metrics_config) - _metrics_task = asyncio.create_task(metrics_server.serve()) + # Metrics 通过主服务器 `/metrics` 端点提供 - # CORS 中间件(严维序评审 #8) - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, - allow_methods=["*"], - allow_headers=["*"], - ) - - # 挂载 webui 子路由 - app.include_router(webui_router) + # webui 路由(暂停挂载,排查路由匹配问题) + # app.include_router(webui_router) # upstream_api_key 启动检查(严维序评审 #5) if not config.upstream_api_key: @@ -538,16 +527,26 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, Any]: except asyncio.CancelledError: pass - _metrics_task.cancel() - try: - await _metrics_task - except asyncio.CancelledError: - pass - await http_client.aclose() logger.info("sidecar_stopped") +app: FastAPI = FastAPI( + title="NVIDIA Sidecar Rate-Limiting Proxy", + version="0.1.0", + lifespan=lifespan, +) + +# CORS 中间件(在 lifespan 前添加,避免 RuntimeError) +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=False, + allow_methods=["*"], + allow_headers=["*"], +) + + def _mask_api_key(key: str) -> str: """对 API Key 进行脱敏处理,仅保留前 4 位以供识别。""" if not key: @@ -557,13 +556,6 @@ def _mask_api_key(key: str) -> str: return key[:4] + "****" -app: FastAPI = FastAPI( - title="NVIDIA Sidecar Rate-Limiting Proxy", - version="0.1.0", - lifespan=lifespan, -) - - # --------------------------------------------------------------------------- # 核心代理处理器 # --------------------------------------------------------------------------- @@ -621,7 +613,13 @@ async def _handle_proxy_request(ctx: SidecarContext, request: Request, path: str # 注入内部元数据到 payload payload_for_queue: dict[str, Any] = dict(body_json) - payload_for_queue["_raw_body"] = body_bytes + # 剥离 NVIDIA provider 前缀(如 "nvidia/deepseek-ai/deepseek-v4-pro" → "deepseek-ai/deepseek-v4-pro") + if model and "/" in model: + stripped_model: str = model.split("/", 1)[1] + payload_for_queue["model"] = stripped_model + bytes_model_stripped: bytes = json.dumps(body_json).encode() + # Update model in the raw body bytes + payload_for_queue["_raw_body"] = json.dumps(payload_for_queue).encode() # 尝试入队;PASSTHROUGH 策略下队列满时走直通路径 try: @@ -768,26 +766,30 @@ async def status(ctx: SidecarContext = Depends(get_context)) -> dict[str, Any]: # ---- OpenAI 兼容端点 ---- @app.post("/v1/chat/completions") -async def chat_completions(request: Request, ctx: SidecarContext = Depends(get_context)) -> Response: +async def chat_completions(request: Request) -> Response: """OpenAI Chat Completions API 代理(含流式支持)。""" + ctx: SidecarContext = get_context() return await _handle_proxy_request(ctx, request, "/v1/chat/completions") @app.post("/v1/completions") -async def completions(request: Request, ctx: SidecarContext = Depends(get_context)) -> Response: +async def completions(request: Request) -> Response: + ctx: SidecarContext = get_context() """OpenAI Completions API 代理(legacy)。""" return await _handle_proxy_request(ctx, request, "/v1/completions") @app.post("/v1/embeddings") -async def embeddings(request: Request, ctx: SidecarContext = Depends(get_context)) -> Response: +async def embeddings(request: Request) -> Response: + ctx: SidecarContext = get_context() """OpenAI Embeddings API 代理。""" return await _handle_proxy_request(ctx, request, "/v1/embeddings") @app.get("/v1/models") @app.get("/v1/models/{model_id:path}") -async def list_models(request: Request, model_id: str | None = None, ctx: SidecarContext = Depends(get_context)) -> Response: +async def list_models(request: Request, model_id: str | None = None) -> Response: + ctx: SidecarContext = get_context() """OpenAI Models API 代理。""" path = f"/v1/models/{model_id}" if model_id else "/v1/models" return await _handle_proxy_request(ctx, request, path) @@ -796,12 +798,22 @@ async def list_models(request: Request, model_id: str | None = None, ctx: Sideca # ---- 通用代理(catch-all 用于非标准 NVIDIA 端点) ---- @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]) -async def catch_all(request: Request, path: str, ctx: SidecarContext = Depends(get_context)) -> Response: +async def catch_all(request: Request, path: str) -> Response: + ctx: SidecarContext = get_context() """通用代理端点:转发任何未匹配的路径到上游。""" target_path = f"/{path}" if not path.startswith("/") else path return await _handle_proxy_request(ctx, request, target_path) +@app.get("/metrics") +async def metrics(ctx: SidecarContext = Depends(get_context)) -> PlainTextResponse: + """Prometheus 指标端点。""" + return PlainTextResponse( + content=ctx.prometheus.generate_latest().decode(), + media_type="text/plain; version=0.0.4", + ) + + # --------------------------------------------------------------------------- # 入口 # ---------------------------------------------------------------------------