Files
EnterpriseArchitect/services/nvidia_sidecar/server.py
T
vincent 1513abbca8 fix(BIZ-41): Phase0 关键修复 — NVIDIA 模型前缀剥离 + URL 路径重复修复
- 模型前缀剥离:_handle_proxy_request 中 NVIDIA 模型去掉 provider 前缀
  (如 "nvidia/deepseek-ai/deepseek-v4-flash" → "deepseek-ai/deepseek-v4-flash")
- URL 路径重复修复:_forward_to_upstream 检查 upstream_url 是否已包含 /v1
  若包含则从 path 中去掉重复的 /v1 前缀
- structlog: PrintLoggerFactory→LoggerFactory
- CORS: 移出 lifespan,在 app 创建后添加
- Metrics: 移除子进程 uvicorn server,改为 @app.get("/metrics") 路由
- Worker: catch-all 异常处理增加 pending future 清理

Co-authored-by: multica-agent <github@multica.ai>
2026-06-25 00:26:59 +08:00

834 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
NVIDIA Sidecar rate-limiting proxy — FastAPI proxy entry point (§3.4)

Full API proxy pipeline:
receive → gateway detection → [NVIDIA: enqueue → token rate limiting] → httpx forward → return

Non-NVIDIA requests pass straight through to the upstream; NVIDIA requests go
through a four-level priority queue plus token-bucket rate limiting.

BIZ-46 Phase3: architecture decoupling — all global state is consolidated into
SidecarContext (§1).
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
import httpx
import structlog
import uvicorn
from fastapi import Depends, FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
from nvidia_sidecar.config import load_config, SidecarConfig
from nvidia_sidecar.context import SidecarContext
from nvidia_sidecar.rate_limiter import (
Priority,
AdaptiveTokenBucket,
is_nvidia_gateway,
)
from nvidia_sidecar.priority_queue import (
PriorityRequestQueue,
QueueFullError,
QueueFullPassthrough,
QueueFullPolicy,
)
from nvidia_sidecar.metrics import PrometheusMetrics
from nvidia_sidecar.health import HealthService
from nvidia_sidecar.webui import webui_router
# ---------------------------------------------------------------------------
# 结构化日志
# ---------------------------------------------------------------------------
# Configure structlog on top of the standard-library logging machinery.
# The processors run in the listed order; the final JSONRenderer turns every
# event into a JSON document.
structlog.configure(
    processors=[
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        structlog.processors.JSONRenderer(),
    ],
    context_class=dict,
    # stdlib LoggerFactory (replaced PrintLoggerFactory per the BIZ-41 change).
    logger_factory=structlog.stdlib.LoggerFactory(),
    wrapper_class=structlog.stdlib.BoundLogger,
    cache_logger_on_first_use=True,
)
# Module-wide structured logger shared by every function in this file.
logger: structlog.stdlib.BoundLogger = structlog.get_logger("nvidia_sidecar")
# ---------------------------------------------------------------------------
# FastAPI 依赖注入
# ---------------------------------------------------------------------------
def get_context() -> SidecarContext:
    """Return the SidecarContext stored on ``app.state`` (FastAPI dependency).

    The context is attached to ``app.state.sidecar`` by ``lifespan`` at
    startup, so this is only valid while the application is running.
    """
    sidecar_ctx = app.state.sidecar
    return sidecar_ctx  # type: ignore[no-any-return]
# ---------------------------------------------------------------------------
# 工具函数
# ---------------------------------------------------------------------------
def _extract_model(body: Any) -> str | None:
"""从请求体中提取模型标识符(兼容 OpenAI Chat/Completions 格式)。
Args:
body: 已解析的 JSON 请求体。
Returns:
模型标识符字符串,或 None。
"""
if isinstance(body, dict):
return str(body.get("model", "")) or None
return None
def _resolve_priority(headers: dict[str, str]) -> Priority:
    """Derive the request priority from the request headers.

    Reads the ``x-priority`` header; the value is matched case-insensitively
    (after trimming whitespace) against ``urgent`` / ``high`` / ``normal`` /
    ``low``. A missing or unrecognised value yields ``Priority.NORMAL``.
    """
    by_name: dict[str, Priority] = {
        "urgent": Priority.URGENT,
        "high": Priority.HIGH,
        "normal": Priority.NORMAL,
        "low": Priority.LOW,
    }
    header_value = headers.get("x-priority", "")
    normalized = header_value.strip().lower()
    if normalized in by_name:
        return by_name[normalized]
    return Priority.NORMAL
# ---------------------------------------------------------------------------
# 上游转发
# ---------------------------------------------------------------------------
async def _forward_to_upstream(
ctx: SidecarContext,
method: str,
path: str,
body: bytes | None,
headers: dict[str, str],
stream: bool = False,
) -> httpx.Response:
"""将请求转发到 NVIDIA 上游 API。
Args:
ctx: SidecarContext 运行时上下文。
method: HTTP 方法。
path: 请求路径(如 ``/v1/chat/completions``)。
body: 原始请求体 bytes。
headers: 要转发的请求 headers(会追加 Authorization)。
stream: 是否请求流式响应。
Returns:
httpx.Response 对象。
Raises:
httpx.HTTPError: HTTP 请求失败。
"""
# 构建上游 URL:如果 upstream_url 已经包含 /v1 路径,则避免路径重复
base_url = ctx.config.upstream_url.rstrip("/")
if base_url.endswith("/v1") and path.startswith("/v1"):
upstream_url = base_url + path[3:] # 去掉 path 中的 /v1 前缀
else:
upstream_url = base_url + path
forward_headers: dict[str, str] = {
k: v for k, v in headers.items()
if k.lower() not in ("host", "content-length", "transfer-encoding")
}
if ctx.config.upstream_api_key:
forward_headers["authorization"] = f"Bearer {ctx.config.upstream_api_key}"
elif "authorization" not in {k.lower() for k in forward_headers}:
forward_headers["authorization"] = "Bearer nvidia"
try:
req = ctx.http_client.build_request(
method=method,
url=upstream_url,
headers=forward_headers,
content=body,
timeout=ctx.config.request_timeout,
)
response = await ctx.http_client.send(req, stream=stream)
return response
except httpx.TimeoutException:
logger.warning("upstream_timeout", path=path, timeout=ctx.config.request_timeout)
raise
except httpx.HTTPError as exc:
logger.error("upstream_error", path=path, error=str(exc))
raise
# ---------------------------------------------------------------------------
# worker 协程:消费优先级队列 + 令牌桶 + 转发
# ---------------------------------------------------------------------------
async def _worker_loop(ctx: SidecarContext) -> None:
    """Background worker: dequeue → acquire token → forward → resolve future.

    Runs until cancelled. For every queue item it looks up the pending future
    registered under the item's ``request_id`` and resolves it with either the
    upstream response or an exception. The pending entry is removed on every
    exit path, and an unexpected error fails the future so that the waiting
    request handler never hangs.

    Args:
        ctx: Runtime context (queue, token bucket, metrics, pending futures).
    """
    log = logger.bind(worker="main")
    log.info("worker_started")
    while True:
        # Reset per iteration so the handlers below only touch the request
        # that is currently being processed.
        request_id = None
        future = None
        try:
            queue_item = await ctx.priority_queue.get(timeout=1.0)
            if queue_item is None:
                continue
            request_id = queue_item.request_id

            # Look up the future the request handler is awaiting.
            pending_entry = ctx.pending_requests.get(request_id)
            if pending_entry is None:
                log.warning("orphan_request", request_id=request_id)
                continue
            future, _ = pending_entry

            if future.done():
                # The waiter is gone (e.g. cancelled); do not spend a token
                # or an upstream call on a result nobody will read.
                log.info("request_abandoned", request_id=request_id)
                continue

            rate_limit_error = await _acquire_worker_token(ctx, queue_item, log)
            if rate_limit_error is not None:
                await ctx.increment_stat("ratelimited_requests")
                ctx.prometheus.record_request(queue_item.priority.name, "ratelimited")
                if not future.done():
                    future.set_exception(rate_limit_error)
                continue

            await _forward_queued_request(ctx, queue_item, future, log)
        except asyncio.CancelledError:
            log.info("worker_cancelled")
            break
        except Exception as exc:
            log.exception("worker_unexpected_error")
            # Fail the in-flight request instead of leaving its future
            # pending forever.
            if future is not None and not future.done():
                future.set_exception(exc)
        finally:
            # Runs on continue/break as well: once a future was looked up,
            # its pending entry must not outlive this iteration.
            if future is not None:
                ctx.pending_requests.pop(request_id, None)


async def _acquire_worker_token(ctx: SidecarContext, queue_item: Any, log: Any) -> _RateLimitedError | None:
    """Obtain one token for ``queue_item``; return the error to report on timeout.

    LOW priority waits via ``try_consume`` with ``low_priority_timeout``.
    Other priorities poll ``consume`` every 0.1 s for up to
    ``request_timeout`` seconds, inside the worker, so the item is never
    re-queued (which would leave its future dangling).
    """
    if queue_item.priority == Priority.LOW:
        # Executed in a worker thread; the original author notes that the
        # token-bucket call blocks.
        got_token = await asyncio.to_thread(
            ctx.token_bucket.try_consume,
            tokens=1,
            timeout=ctx.config.low_priority_timeout,
        )
        if got_token:
            return None
        log.info("low_priority_timeout", request_id=queue_item.request_id)
        return _RateLimitedError(
            f"低优先级请求令牌等待超时 ({ctx.config.low_priority_timeout}s)"
        )

    got_token = await asyncio.to_thread(ctx.token_bucket.consume, tokens=1)
    if not got_token:
        token_deadline = time.monotonic() + ctx.config.request_timeout
        while not got_token:
            await asyncio.sleep(0.1)
            got_token = await asyncio.to_thread(ctx.token_bucket.consume, tokens=1)
            if time.monotonic() > token_deadline:
                break
    if got_token:
        return None
    log.warning(
        "token_wait_timeout",
        request_id=queue_item.request_id,
        priority=queue_item.priority.name,
        timeout=ctx.config.request_timeout,
    )
    return _RateLimitedError(f"令牌等待超时 ({ctx.config.request_timeout:.0f}s)")


async def _forward_queued_request(ctx: SidecarContext, queue_item: Any, future: Any, log: Any) -> None:
    """Forward one queued request upstream, record metrics and resolve ``future``.

    Transport failures (``httpx.HTTPError`` / ``OSError``) are set on the
    future; any other exception propagates to the caller.
    """
    request_id = queue_item.request_id
    payload = queue_item.payload
    headers = queue_item.headers
    priority_name = queue_item.priority.name
    upstream_start = time.monotonic()
    try:
        path = headers.get("x-original-path", "/v1/chat/completions")
        method = headers.get("x-original-method", "POST")
        # Strip the internal bookkeeping headers before forwarding.
        clean_headers = {
            k: v for k, v in headers.items()
            if not k.startswith("x-original-") and not k.startswith("x-request-id")
        }
        resp = await _forward_to_upstream(
            ctx=ctx,
            method=method,
            path=path,
            body=payload.get("_raw_body"),
            headers=clean_headers,
            stream=payload.get("stream", False),
        )
        upstream_latency = time.monotonic() - upstream_start
        queue_latency = time.monotonic() - queue_item.enqueued_at
        total_latency = upstream_latency + queue_latency

        # Feed the outcome to the adaptive bucket, re-evaluate the retreat
        # state, then publish the refreshed values.
        ctx.token_bucket.record_response(resp.status_code == 429)
        ctx.token_bucket.evaluate_retreat()
        retreat_state = ctx.token_bucket.get_retreat_state()
        effective_rpm = ctx.token_bucket.get_effective_rate_rpm()
        upstream_429_rate = ctx.token_bucket.get_429_rate()
        ctx.prometheus.update_retreat_metrics(retreat_state, effective_rpm, upstream_429_rate)

        # Model-level detail goes to the JSON log (BIZ-46 Phase3: kept here
        # after the provider label was collapsed).
        model_id = _extract_model(payload) or "unknown"
        log.info(
            "request_completed",
            request_id=request_id,
            status=resp.status_code,
            model_id=model_id,
            upstream_latency=round(upstream_latency, 3),
            queue_latency=round(queue_latency, 3),
            total_latency=round(total_latency, 3),
            retreat_state=retreat_state,
            effective_rpm=round(effective_rpm, 1),
        )

        # Prometheus metrics — single collapsed provider label (BIZ-46 Phase3).
        provider = "nvidia"
        ctx.prometheus.record_upstream_latency(provider, upstream_latency)
        if not resp.is_success:
            ctx.prometheus.record_upstream_error(resp.status_code, provider)
        ctx.prometheus.record_request(priority_name, "success" if resp.is_success else "error")
        ctx.prometheus.record_queue_latency(priority_name, queue_latency)

        if not future.done():
            future.set_result(resp)
        else:
            # Nobody will read this response; release its connection.
            await resp.aclose()
    except (httpx.HTTPError, OSError) as exc:
        log.error("upstream_request_failed", request_id=request_id, error=str(exc))
        await ctx.increment_stat("upstream_errors")
        ctx.prometheus.record_request(priority_name, "error")
        ctx.prometheus.set_health(False)
        if not future.done():
            future.set_exception(exc)
# ---------------------------------------------------------------------------
# PASSTHROUGH 直通路径(队列满 + PASSTHROUGH 策略)
# ---------------------------------------------------------------------------
async def _passthrough_with_rate_limit(
    ctx: SidecarContext,
    request: Request,
    path: str,
    body_bytes: bytes,
    raw_headers: dict[str, str],
    priority: Priority,
) -> Response:
    """PASSTHROUGH path for a full queue: token-bucket limited, but not queued.

    Args:
        ctx: Runtime context.
        request: Incoming FastAPI request (only its method is used).
        path: Request path to forward.
        body_bytes: Raw request body.
        raw_headers: Request headers.
        priority: Request priority; LOW uses ``low_priority_timeout``, other
            priorities poll for up to ``request_timeout`` seconds.

    Returns:
        The upstream response, a 429 JSON error when no token could be
        obtained in time, or a mapped JSON error when forwarding failed.
    """
    await ctx.increment_stat("passthrough_requests")
    ctx.prometheus.increment_fallback()

    if priority == Priority.LOW:
        # LOW priority waits on the bucket with its own timeout.
        got_token = await asyncio.to_thread(
            ctx.token_bucket.try_consume,
            tokens=1,
            timeout=ctx.config.low_priority_timeout,
        )
        timeout_message = (
            f"令牌不足(队列满 + passthrough),超时 {ctx.config.low_priority_timeout}s"
        )
    else:
        got_token = await asyncio.to_thread(ctx.token_bucket.consume, tokens=1)
        deadline = time.monotonic() + ctx.config.request_timeout
        # Stop polling as soon as a token is obtained: a token won on the
        # final poll must be used, not discarded because the deadline has
        # just elapsed.
        while not got_token and time.monotonic() <= deadline:
            await asyncio.sleep(0.1)
            got_token = await asyncio.to_thread(ctx.token_bucket.consume, tokens=1)
        timeout_message = (
            f"令牌不足(队列满 + passthrough),等待超时 {ctx.config.request_timeout:.0f}s"
        )

    if not got_token:
        await ctx.increment_stat("ratelimited_requests")
        ctx.prometheus.record_request(priority.name, "ratelimited")
        return JSONResponse(
            status_code=429,
            content={
                "error": {
                    "message": timeout_message,
                    "type": "RateLimitedError",
                }
            },
        )

    # Token acquired: forward directly.
    try:
        resp = await _forward_to_upstream(
            ctx=ctx,
            method=request.method,
            path=path,
            body=body_bytes if body_bytes else None,
            headers=dict(raw_headers),
            stream=False,
        )
        # Same feedback sequence as the queue worker: record the outcome,
        # re-evaluate, and only then read the (fresh) retreat state.
        ctx.token_bucket.record_response(resp.status_code == 429)
        ctx.token_bucket.evaluate_retreat()
        ctx.prometheus.update_retreat_metrics(
            ctx.token_bucket.get_retreat_state(),
            ctx.token_bucket.get_effective_rate_rpm(),
            ctx.token_bucket.get_429_rate(),
        )
        return _build_response(resp)
    except Exception as exc:
        status, msg = _map_exception(exc)
        logger.error("passthrough_error", path=path, error=str(exc))
        ctx.prometheus.set_health(False)
        return JSONResponse(
            status_code=status,
            content={"error": {"message": msg, "type": type(exc).__name__}},
        )
# ---------------------------------------------------------------------------
# 自定义异常
# ---------------------------------------------------------------------------
class _RateLimitedError(Exception):
"""429 限流错误。"""
pass
# ---------------------------------------------------------------------------
# 异常处理矩阵 (§3.4)
# ---------------------------------------------------------------------------
# Exception type → (HTTP status, message). Entries are checked in insertion
# order with ``isinstance``, so the first matching type wins.
_EXCEPTION_MATRIX: dict[type[Exception], tuple[int, str]] = {
    _RateLimitedError: (429, "Too Many Requests — 令牌不足"),
    QueueFullError: (503, "Service Unavailable — 队列已满"),
    httpx.TimeoutException: (504, "Gateway Timeout — 上游超时"),
    httpx.ConnectError: (502, "Bad Gateway — 上游连接失败"),
    httpx.HTTPStatusError: (502, "Bad Gateway — 上游返回错误状态"),
}


def _map_exception(exc: Exception) -> tuple[int, str]:
    """Translate an exception into an ``(HTTP status, message)`` pair.

    Falls back to a 500 that names the exception class when no entry of
    ``_EXCEPTION_MATRIX`` matches.
    """
    matched = next(
        (
            entry
            for exc_type, entry in _EXCEPTION_MATRIX.items()
            if isinstance(exc, exc_type)
        ),
        None,
    )
    if matched is None:
        return 500, f"Internal Server Error — {type(exc).__name__}"
    return matched
# ---------------------------------------------------------------------------
# FastAPI 应用 + lifespan
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, Any]:
    """Application lifecycle: build the shared resources, then tear them down.

    BIZ-46 Phase3: every resource lives in one ``SidecarContext`` that is
    attached to ``app.state.sidecar``.

    Teardown runs in a ``finally`` block, so the worker task is cancelled and
    the HTTP client is closed even when startup fails part-way or an
    exception is thrown into the context manager while the app is running.
    """
    # --- startup ---
    config: SidecarConfig = load_config()
    logging.getLogger().setLevel(config.log_level.upper())
    http_client: httpx.AsyncClient = httpx.AsyncClient(
        timeout=httpx.Timeout(config.request_timeout),
        limits=httpx.Limits(
            max_connections=100,
            max_keepalive_connections=20,
        ),
    )
    worker_task: asyncio.Task[None] | None = None
    try:
        priority_queue: PriorityRequestQueue = PriorityRequestQueue(max_size=config.queue_max_size)
        token_bucket: AdaptiveTokenBucket = AdaptiveTokenBucket(
            rate=config.rate_rpm / 60.0,  # configured per minute, passed per second
            capacity=config.bucket_capacity,
        )
        prometheus: PrometheusMetrics = PrometheusMetrics()
        health: HealthService = HealthService()
        ctx: SidecarContext = SidecarContext(
            config=config,
            http_client=http_client,
            token_bucket=token_bucket,
            priority_queue=priority_queue,
            prometheus=prometheus,
            health=health,
        )
        ctx.stats["start_time"] = int(time.time())
        app.state.sidecar = ctx  # expose to FastAPI handlers

        # Start the queue-consuming worker coroutine.
        worker_task = asyncio.create_task(_worker_loop(ctx))

        # Metrics are served by the main server's `/metrics` endpoint.
        # TODO: webui router mounting is paused while a route-matching problem
        # is investigated: app.include_router(webui_router)

        # Startup check for upstream_api_key (review item #5).
        if not config.upstream_api_key:
            logger.warning(
                "upstream_api_key_empty",
                message="SIDECAR_API_KEY 未设置,NVIDIA 请求将因 401 认证失败",
            )
        logger.info(
            "sidecar_started",
            host=config.listen_host,
            port=config.listen_port,
            metrics_port=config.metrics_port,
            rate_rpm=config.rate_rpm,
            queue_max=config.queue_max_size,
            retreat_enabled=True,
        )
        yield  # application is running
    finally:
        # --- shutdown ---
        if worker_task is not None:
            worker_task.cancel()
            try:
                await worker_task
            except asyncio.CancelledError:
                pass
        await http_client.aclose()
        logger.info("sidecar_stopped")
# The ASGI application; startup and shutdown are handled by ``lifespan``.
app: FastAPI = FastAPI(
    title="NVIDIA Sidecar Rate-Limiting Proxy",
    version="0.1.0",
    lifespan=lifespan,
)
# CORS middleware is registered right after the app is created (not inside
# lifespan) to avoid a RuntimeError.
# Any origin, method and header is allowed; credentials are not.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
def _mask_api_key(key: str) -> str:
"""对 API Key 进行脱敏处理,仅保留前 4 位以供识别。"""
if not key:
return ""
if len(key) <= 4:
return key[:2] + "****"
return key[:4] + "****"
# ---------------------------------------------------------------------------
# 核心代理处理器
# ---------------------------------------------------------------------------
async def _handle_proxy_request(ctx: SidecarContext, request: Request, path: str) -> Response:
    """Single entry point for every proxied request.

    Pipeline:
      1. Read the body and extract ``model`` from it when it is a JSON object.
      2. Gateway detection — non-NVIDIA requests are forwarded directly.
      3. NVIDIA requests are enqueued; this handler then awaits the future
         that the queue worker resolves after rate limiting and forwarding.

    Args:
        ctx: Runtime context.
        request: Incoming FastAPI request.
        path: Upstream path to forward to (without query string).

    Returns:
        The upstream response, or a JSON error response (429 / 503 / mapped
        status) when the request could not be served.
    """
    await ctx.increment_stat("total_requests")

    body_bytes: bytes = await request.body()
    raw_headers: dict[str, str] = dict(request.headers)

    # Carry the query string along; forwarding only ``path`` would drop it.
    query = request.url.query
    forward_path = f"{path}?{query}" if query else path

    # Only a JSON *object* is usable below; arrays, scalars and invalid JSON
    # are treated as "no metadata" instead of crashing on ``.get``.
    body_json: dict[str, Any] = {}
    if body_bytes:
        try:
            parsed = json.loads(body_bytes)
        except (ValueError, TypeError):
            parsed = None
        if isinstance(parsed, dict):
            body_json = parsed

    model: str | None = _extract_model(body_json)
    is_nvidia: bool = is_nvidia_gateway(model)

    # Non-NVIDIA → forward directly.
    if not is_nvidia:
        await ctx.increment_stat("passthrough_requests")
        try:
            resp = await _forward_to_upstream(
                ctx=ctx,
                method=request.method,
                path=forward_path,
                body=body_bytes if body_bytes else None,
                headers=raw_headers,
                stream=body_json.get("stream", False),
            )
            return _build_response(resp)
        except Exception as exc:
            status, msg = _map_exception(exc)
            logger.error("passthrough_error", path=path, error=str(exc))
            return JSONResponse(
                status_code=status,
                content={"error": {"message": msg, "type": type(exc).__name__}},
            )

    # NVIDIA → queue + rate limit + forward.
    await ctx.increment_stat("nvidia_requests")
    priority: Priority = _resolve_priority(raw_headers)

    payload_for_queue: dict[str, Any] = dict(body_json)
    forward_body: bytes = body_bytes
    # Strip the provider prefix, e.g.
    # "nvidia/deepseek-ai/deepseek-v4-pro" → "deepseek-ai/deepseek-v4-pro".
    # NOTE(review): this drops the first segment of any model containing "/";
    # confirm that is_nvidia_gateway only accepts provider-prefixed names.
    if model and "/" in model:
        payload_for_queue["model"] = model.split("/", 1)[1]
        # Re-serialise only when the payload actually changed.
        forward_body = json.dumps(payload_for_queue).encode()
    # Internal metadata: the exact bytes the worker must send upstream.
    payload_for_queue["_raw_body"] = forward_body

    # Enqueue; under the PASSTHROUGH policy a full queue takes the direct path.
    try:
        request_id = await ctx.priority_queue.put(
            item=payload_for_queue,
            priority=priority,
            headers={
                **raw_headers,
                "x-original-path": forward_path,
                "x-original-method": request.method,
            },
        )
    except QueueFullError:
        await ctx.increment_stat("queue_full_rejects")
        return JSONResponse(
            status_code=503,
            content={
                "error": {
                    "message": "队列已满,当前策略: reject",
                    "type": "QueueFullError",
                }
            },
        )
    except QueueFullPassthrough:
        # _passthrough_with_rate_limit counts "passthrough_requests" itself;
        # counting here as well would double the statistic. The prefix-stripped
        # body is forwarded so both NVIDIA paths send the same model name.
        logger.info("queue_full_passthrough", path=path)
        return await _passthrough_with_rate_limit(
            ctx, request, forward_path, forward_body, raw_headers, priority
        )

    # Register the future the worker resolves for this request id.
    loop = asyncio.get_running_loop()
    future: asyncio.Future[httpx.Response] = loop.create_future()
    ctx.pending_requests[request_id] = (future, time.monotonic())
    try:
        resp = await future
        return _build_response(resp)
    except _RateLimitedError as exc:
        return JSONResponse(
            status_code=429,
            content={
                "error": {
                    "message": str(exc),
                    "type": "RateLimitedError",
                }
            },
        )
    except Exception as exc:
        status, msg = _map_exception(exc)
        logger.error("proxy_error", path=path, request_id=request_id, error=str(exc))
        return JSONResponse(
            status_code=status,
            content={"error": {"message": msg, "type": type(exc).__name__}},
        )
    finally:
        # Also runs when this handler is cancelled (e.g. client disconnect),
        # so the entry cannot leak.
        ctx.pending_requests.pop(request_id, None)
def _build_response(resp: httpx.Response) -> Response:
    """Convert an ``httpx.Response`` into a FastAPI response.

    Buffered responses are returned as a plain ``Response``. Streaming (SSE)
    responses — and any upstream response whose body has not been read yet —
    are relayed through a ``StreamingResponse`` that closes the upstream
    response when the relay ends.

    Args:
        resp: The upstream response, either fully read or still streaming.

    Returns:
        A ``Response`` or ``StreamingResponse`` with the upstream status code.
    """
    content_type = resp.headers.get("content-type", "")
    # The body handed to the client is already decoded by httpx, so the
    # upstream framing headers no longer describe it. In particular a stale
    # content-length (size of the compressed body) would not match.
    excluded = ("content-encoding", "content-length", "transfer-encoding")
    passthrough_headers = {
        k: v for k, v in resp.headers.items() if k.lower() not in excluded
    }
    is_event_stream = "text/event-stream" in content_type or "stream" in content_type

    if not is_event_stream:
        try:
            buffered_body = resp.content
        except httpx.ResponseNotRead:
            # Requested with stream=True but upstream answered with a non-SSE
            # body (typically a JSON error): relay it as a stream instead.
            buffered_body = None
        if buffered_body is not None:
            return Response(
                content=buffered_body,
                status_code=resp.status_code,
                headers=passthrough_headers,
                media_type=content_type or "application/json",
            )

    async def _relay() -> AsyncGenerator[bytes, None]:
        """Yield upstream chunks and always release the upstream response."""
        try:
            async for chunk in resp.aiter_bytes():
                yield chunk
        finally:
            await resp.aclose()

    return StreamingResponse(
        content=_relay(),
        status_code=resp.status_code,
        headers=passthrough_headers,
        media_type="text/event-stream" if is_event_stream else (content_type or "application/json"),
    )
# ---------------------------------------------------------------------------
# 路由
# ---------------------------------------------------------------------------
@app.get("/health")
async def health(ctx: SidecarContext = Depends(get_context)) -> dict[str, Any]:
    """Liveness probe: return the health service's liveness report."""
    liveness_report = ctx.health.liveness()
    return liveness_report
@app.get("/health/ready")
async def health_ready(ctx: SidecarContext = Depends(get_context)) -> dict[str, Any]:
    """Readiness probe, including upstream connectivity.

    BIZ-46 Phase3: the shared ``ctx.http_client`` is passed in, so no new
    client is created per check.
    """
    current_queue_size = await ctx.priority_queue.get_queue_size()
    bucket = ctx.token_bucket.get_status()
    cfg = ctx.config
    readiness_report = await ctx.health.readiness(
        upstream_url=cfg.upstream_url,
        upstream_api_key=cfg.upstream_api_key or "",
        queue_current_size=current_queue_size,
        queue_max_size=cfg.queue_max_size,
        available_tokens=bucket["tokens"],
        bucket_capacity=bucket["capacity"],
        http_client=ctx.http_client,  # reuse the main client
    )
    return readiness_report
@app.get("/status")
async def status(ctx: SidecarContext = Depends(get_context)) -> dict[str, Any]:
    """Debug endpoint: full rate-limiter, queue and retreat state."""
    queue_snapshot = await ctx.priority_queue.get_stats()
    bucket_snapshot = ctx.token_bucket.get_status()
    counters = ctx.stats

    request_counts = {
        "total": counters["total_requests"],
        "nvidia": counters["nvidia_requests"],
        "passthrough": counters["passthrough_requests"],
        "ratelimited": counters["ratelimited_requests"],
    }
    error_counts = {
        "queue_full_rejects": counters["queue_full_rejects"],
        "upstream_errors": counters["upstream_errors"],
    }
    bucket = ctx.token_bucket
    retreat_info = {
        "state": bucket.get_retreat_state(),
        "effective_rpm": round(bucket.get_effective_rate_rpm(), 1),
        "base_rpm": round(bucket.get_base_rate_rpm(), 1),
        "upstream_429_rate": round(bucket.get_429_rate(), 4),
    }
    return {
        "requests": request_counts,
        "errors": error_counts,
        "queue": queue_snapshot,
        "token_bucket": bucket_snapshot,
        "retreat": retreat_info,
        "uptime_seconds": ctx.uptime_seconds,
    }
# ---- OpenAI 兼容端点 ----
@app.post("/v1/chat/completions")
async def chat_completions(
    request: Request,
    ctx: SidecarContext = Depends(get_context),
) -> Response:
    """Proxy for the OpenAI Chat Completions API (streaming supported).

    The context is injected with ``Depends(get_context)``, like the health and
    status routes, so FastAPI dependency overrides also apply here.
    """
    return await _handle_proxy_request(ctx, request, "/v1/chat/completions")
@app.post("/v1/completions")
async def completions(
    request: Request,
    ctx: SidecarContext = Depends(get_context),
) -> Response:
    """Proxy for the legacy OpenAI Completions API.

    The docstring is the first statement of the function so that it is a real
    docstring (and shows up in the generated API docs).
    """
    return await _handle_proxy_request(ctx, request, "/v1/completions")
@app.post("/v1/embeddings")
async def embeddings(
    request: Request,
    ctx: SidecarContext = Depends(get_context),
) -> Response:
    """Proxy for the OpenAI Embeddings API.

    The docstring is the first statement of the function so that it is a real
    docstring (and shows up in the generated API docs).
    """
    return await _handle_proxy_request(ctx, request, "/v1/embeddings")
@app.get("/v1/models")
@app.get("/v1/models/{model_id:path}")
async def list_models(
    request: Request,
    model_id: str | None = None,
    ctx: SidecarContext = Depends(get_context),
) -> Response:
    """Proxy for the OpenAI Models API (list, or a single model by id).

    ``model_id`` may itself contain slashes because the route uses a ``path``
    converter.
    """
    path = f"/v1/models/{model_id}" if model_id else "/v1/models"
    return await _handle_proxy_request(ctx, request, path)
# ---- Prometheus metrics ----
# Must be registered BEFORE the catch-all route below: routes are matched in
# registration order, so a GET catch-all declared first would shadow /metrics
# and proxy it to the upstream instead.
@app.get("/metrics")
async def metrics(ctx: SidecarContext = Depends(get_context)) -> PlainTextResponse:
    """Prometheus metrics endpoint."""
    return PlainTextResponse(
        content=ctx.prometheus.generate_latest().decode(),
        media_type="text/plain; version=0.0.4",
    )


# ---- Generic proxy (catch-all for non-standard NVIDIA endpoints) ----
# Keep this as the LAST registered route.
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
async def catch_all(
    request: Request,
    path: str,
    ctx: SidecarContext = Depends(get_context),
) -> Response:
    """Generic proxy endpoint: forward any otherwise unmatched path upstream."""
    target_path = path if path.startswith("/") else f"/{path}"
    return await _handle_proxy_request(ctx, request, target_path)
# ---------------------------------------------------------------------------
# 入口
# ---------------------------------------------------------------------------
def main() -> None:
    """Development/debug entry point: serve the app with uvicorn."""
    import uvicorn

    settings: SidecarConfig = load_config()
    run_options = {
        "host": settings.listen_host,
        "port": settings.listen_port,
        "log_level": settings.log_level.lower(),
    }
    uvicorn.run("nvidia_sidecar.server:app", **run_options)


if __name__ == "__main__":
    main()