BIZ-40: NVIDIA Sidecar 限流代理 Phase1 — 核心代理模块

交付文件:
- config.py: 配置管理 (SidecarConfig + load_config),修复 PEP 563 类型推断 bug
- rate_limiter.py: 令牌桶 (TokenBucket) + 网关识别 (is_nvidia_gateway)
- priority_queue.py: 四级优先级队列,修复 PASSTHROUGH 语义 bug
- server.py: FastAPI 代理主入口,修复 worker_loop 重试悬挂 bug
- __init__.py: 包声明与公开导出
- pyproject.toml: 依赖声明 + mypy 配置
- README.md: 快速启动指南 + 环境变量列表

评审修复:
- worker_loop 令牌重试从重入队改为 poll-wait (防止 future 悬挂)
- 路由函数 + lifespan 补充返回类型注解
- heapq 重复 import 移到文件顶部
- config.py 清理无用代码行
- types-PyYAML stub 安装
- 新增 README.md

验证: mypy 0 issues, 全量单元测试通过

Co-authored-by: multica-agent <github@multica.ai>
This commit is contained in:
2026-06-24 08:32:47 +08:00
parent cca4089f2a
commit 6b5f53a0fd
7 changed files with 1488 additions and 0 deletions
+63
View File
@@ -0,0 +1,63 @@
# NVIDIA Sidecar 限流代理
为 NVIDIA API 提供**优先级排队 + 令牌桶限流**的透明代理层。
## 快速启动
```bash
pip install .
nvidia-sidecar
```
监听 `127.0.0.1:9190`,代理到 NVIDIA API。
## 环境变量
| 变量 | 默认值 | 说明 |
|------|--------|------|
| `SIDECAR_HOST` | `127.0.0.1` | 监听地址 |
| `SIDECAR_PORT` | `9190` | 监听端口 |
| `SIDECAR_METRICS_PORT` | `9191` | Metrics 端口 |
| `SIDECAR_UPSTREAM` | `https://integrate.api.nvidia.com/v1` | 上游 API 地址 |
| `SIDECAR_API_KEY` | — | NVIDIA API Key(必填) |
| `SIDECAR_RATE_RPM` | `40` | 每分钟请求数限制 |
| `SIDECAR_BUCKET_CAPACITY` | `40` | 令牌桶容量 |
| `SIDECAR_TIMEOUT` | `6000` | 上游请求超时(秒) |
| `SIDECAR_QUEUE_MAX` | `500` | 队列最大长度 |
| `SIDECAR_LOW_TIMEOUT` | `2.0` | 低优先级令牌等待超时(秒) |
| `SIDECAR_FALLBACK_PASSTHROUGH` | `true` | 队列满时是否直通上游 |
| `SIDECAR_LOG_LEVEL` | `INFO` | 日志级别 |
## YAML 配置
```yaml
listen_port: 9292
rate_rpm: 60
upstream_api_key: "nvapi-xxx"
```
```bash
nvidia-sidecar --config /etc/nvidia-sidecar.yaml
```
## API 端点
| 路径 | 方法 | 说明 |
|------|------|------|
| `/v1/chat/completions` | POST | OpenAI Chat Completions 代理 |
| `/v1/completions` | POST | OpenAI Completions 代理(legacy |
| `/v1/embeddings` | POST | OpenAI Embeddings 代理 |
| `/v1/models` | GET | 模型列表代理 |
| `/health` | GET | 健康检查 |
| `/metrics` | GET | 指标查询 |
## 架构
```
请求 → 网关识别 → [NVIDIA: 优先级排队 → 令牌桶限流] → httpx → NVIDIA API
→ [非 NVIDIA: 直通] → httpx → 上游
```
- **四级优先级**: URGENT > HIGH > NORMAL > LOW(通过 `X-Priority` header 指定)
- **队列满策略**: PASSTHROUGH(直通)/ REJECT503/ DROP_LOWEST(丢弃最低优先级)
- **令牌桶**: 40 RPM,线程安全,支持阻塞/非阻塞消费
+41
View File
@@ -0,0 +1,41 @@
"""
NVIDIA Sidecar 限流代理 — 核心代理模块。
为 OpenAI Chat Completions 兼容 API 提供四层防护:
1. 请求接收(FastAPI
2. 网关识别 → 非 NVIDIA 直通
3. 优先级排队 → 令牌桶限流
4. httpx 异步转发到 NVIDIA 上游
"""
from __future__ import annotations
from nvidia_sidecar.config import SidecarConfig, load_config
from nvidia_sidecar.rate_limiter import (
Priority,
TokenBucket,
is_nvidia_gateway,
normalize_gateway_name,
)
from nvidia_sidecar.priority_queue import (
PriorityQueueItem,
PriorityRequestQueue,
QueueFullError,
QueueFullPassthrough,
QueueFullPolicy,
)
__version__ = "0.1.0"
__all__ = [
"SidecarConfig",
"load_config",
"Priority",
"TokenBucket",
"is_nvidia_gateway",
"normalize_gateway_name",
"PriorityQueueItem",
"PriorityRequestQueue",
"QueueFullError",
"QueueFullPassthrough",
"QueueFullPolicy",
]
+216
View File
@@ -0,0 +1,216 @@
"""
NVIDIA Sidecar 限流代理 — 配置管理模块 (§3.1)
集中管理 Sidecar 运行参数,支持环境变量覆盖和 YAML 配置文件。
"""
from __future__ import annotations
import os
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class SidecarConfig:
"""Sidecar 运行配置数据类。
所有字段可通过环境变量覆盖,优先级:环境变量 > YAML 配置文件 > 默认值。
"""
# ---- 网络 ----
listen_host: str = field(
default="127.0.0.1",
metadata={"env": "SIDECAR_HOST"},
)
listen_port: int = field(
default=9190,
metadata={"env": "SIDECAR_PORT"},
)
metrics_port: int = field(
default=9191,
metadata={"env": "SIDECAR_METRICS_PORT"},
)
# ---- 上游 ----
upstream_url: str = field(
default="https://integrate.api.nvidia.com/v1",
metadata={"env": "SIDECAR_UPSTREAM"},
)
upstream_api_key: str = field(
default="",
metadata={"env": "SIDECAR_API_KEY"},
)
# ---- 限流 ----
rate_rpm: int = field(
default=40,
metadata={"env": "SIDECAR_RATE_RPM"},
)
bucket_capacity: int = field(
default=40,
metadata={"env": "SIDECAR_BUCKET_CAPACITY"},
)
# ---- 超时 ----
request_timeout: float = field(
default=6000.0,
metadata={"env": "SIDECAR_TIMEOUT"},
)
# ---- 队列 ----
queue_max_size: int = field(
default=500,
metadata={"env": "SIDECAR_QUEUE_MAX"},
)
low_priority_timeout: float = field(
default=2.0,
metadata={"env": "SIDECAR_LOW_TIMEOUT"},
)
# ---- 降级 ----
fallback_enabled_passthrough: bool = field(
default=True,
metadata={"env": "SIDECAR_FALLBACK_PASSTHROUGH"},
)
# ---- 日志 ----
log_level: str = field(
default="INFO",
metadata={"env": "SIDECAR_LOG_LEVEL"},
)
def _apply_env_overrides(config: SidecarConfig) -> SidecarConfig:
"""用环境变量覆盖配置字段。
遍历 SidecarConfig 的 dataclass fields,对每个声明了 ``metadata={"env": ...}``
的字段检查环境变量是否存在,存在则用对应类型转换后覆盖。
"""
import dataclasses as _dc
# 使用 typing.get_type_hints 解析 from __future__ import annotations
# 引入的字符串化类型注解 (PEP 563)
try:
resolved_types = __import__("typing").get_type_hints(type(config))
except Exception:
resolved_types = {}
for fld in _dc.fields(config):
env_key: str | None = fld.metadata.get("env")
if env_key is None:
continue
env_val = os.environ.get(env_key)
if env_val is None:
continue
target_type = resolved_types.get(fld.name, fld.type)
target_type_name: str = getattr(target_type, "__name__", str(target_type))
try:
if target_type is bool or target_type == "bool":
parsed: bool = env_val.strip().lower() in ("true", "1", "yes", "on")
setattr(config, fld.name, parsed)
elif target_type is int or target_type == "int":
setattr(config, fld.name, int(env_val))
elif target_type is float or target_type == "float":
setattr(config, fld.name, float(env_val))
else:
setattr(config, fld.name, env_val)
except (ValueError, TypeError) as exc:
warnings.warn(
f"无法将环境变量 {env_key}={env_val!r} 转换为 {target_type_name}: {exc}"
)
return config
def _validate_config(config: SidecarConfig) -> list[str]:
"""验证配置合理性,返回警告/问题列表。"""
issues: list[str] = []
# 端口冲突检查
if config.listen_port == config.metrics_port:
issues.append(
f"listen_port ({config.listen_port}) 与 metrics_port ({config.metrics_port}) 相同"
)
# rate_rpm 边界检查
if config.rate_rpm <= 0:
issues.append(
f"rate_rpm ({config.rate_rpm}) 无效,回退到默认值 40"
)
config.rate_rpm = 40
# queue_max_size 合理性
if config.queue_max_size <= 0:
issues.append(
f"queue_max_size ({config.queue_max_size}) 无效,回退到默认值 500"
)
config.queue_max_size = 500
# request_timeout 合理性
if config.request_timeout <= 0:
issues.append(
f"request_timeout ({config.request_timeout}) 无效,回退到默认值 6000"
)
config.request_timeout = 6000.0
return issues
def load_config(path: str | None = None) -> SidecarConfig:
"""加载 Sidecar 配置。
加载顺序(后者覆盖前者):
1. 默认值(SidecarConfig dataclass defaults
2. YAML 配置文件(如果 path 提供)
3. 环境变量覆盖
Args:
path: 可选 YAML 配置文件路径。为 None 时只使用默认值 + 环境变量。
Returns:
经过验证的 SidecarConfig 实例。
Raises:
FileNotFoundError: path 指定的文件不存在。
yaml.YAMLError: YAML 解析失败。
"""
config = SidecarConfig()
if path is not None:
import yaml
cfg_path = Path(path)
if not cfg_path.is_file():
raise FileNotFoundError(f"配置文件不存在: {cfg_path}")
try:
raw: dict[str, Any] = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
raise yaml.YAMLError(f"YAML 解析失败 ({cfg_path}): {exc}") from exc
# 覆盖已声明的字段
for fld_name in (
"listen_host", "listen_port", "metrics_port",
"upstream_url", "upstream_api_key",
"rate_rpm", "bucket_capacity",
"request_timeout",
"queue_max_size", "low_priority_timeout",
"fallback_enabled_passthrough",
"log_level",
):
if fld_name in raw:
setattr(config, fld_name, raw[fld_name])
# 环境变量覆盖(最高优先级)
config = _apply_env_overrides(config)
# 验证
issues = _validate_config(config)
for issue in issues:
warnings.warn(issue)
return config
+226
View File
@@ -0,0 +1,226 @@
"""
NVIDIA Sidecar 限流代理 — 四级优先级请求队列模块 (§3.3)
管理待处理的 NVIDIA API 请求,按优先级 + FIFO 出队。
支持三种队列满策略:PASSTHROUGH / REJECT / DROP_LOWEST。
"""
from __future__ import annotations
import asyncio
import heapq
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from nvidia_sidecar.rate_limiter import Priority
# ---------------------------------------------------------------------------
# 队列满策略
# ---------------------------------------------------------------------------
class QueueFullPolicy(str, Enum):
"""队列满时的处理策略。"""
PASSTHROUGH = "passthrough" # 直通上游,绕过排队(fail-open 子策略)
REJECT = "reject" # 返回 503 Service Unavailable
DROP_LOWEST = "drop_lowest" # 丢弃队列中最低优先级元素,插入新请求
# ---------------------------------------------------------------------------
# 队列元素
# ---------------------------------------------------------------------------
@dataclass(order=True)
class PriorityQueueItem:
"""优先级队列元素。
``sort_index`` 由 ``(priority, timestamp)`` 组成,
Python 的 ``__lt__`` 按字段顺序比较:先比 priority,再比 timestamp。
数值越小越优先(URGENT=1 优于 HIGH=2)。
"""
sort_index: tuple[int, float] = field(compare=True)
priority: Priority = field(compare=False)
request_id: str = field(compare=False)
payload: dict[str, Any] = field(compare=False)
enqueued_at: float = field(compare=False)
headers: dict[str, str] = field(default_factory=dict, compare=False)
# ---------------------------------------------------------------------------
# 优先级请求队列
# ---------------------------------------------------------------------------
class QueueFullError(Exception):
"""队列已满且策略为 REJECT 时抛出。"""
pass
class QueueFullPassthrough(Exception):
"""队列已满且策略为 PASSTHROUGH 时抛出,由调用方绕过队列直通上游。"""
pass
class PriorityRequestQueue:
"""异步线程安全的四级优先级请求队列。
内部使用 ``asyncio.Lock`` 保护并发操作,
基于 ``heapq`` + ``asyncio.Event`` 实现阻塞出队。
"""
def __init__(self, max_size: int = 500) -> None:
"""初始化优先级队列。
Args:
max_size: 队列最大容量。
Raises:
ValueError: max_size <= 0。
"""
if max_size <= 0:
raise ValueError(f"max_size 必须为正整数,当前值: {max_size}")
self.max_size: int = max_size
self._heap: list[PriorityQueueItem] = []
self._lock: asyncio.Lock = asyncio.Lock()
self._not_empty: asyncio.Event = asyncio.Event()
self._full_policy: QueueFullPolicy = QueueFullPolicy.PASSTHROUGH
# 统计
self._total_enqueued: int = 0
self._total_dequeued: int = 0
self._total_dropped: int = 0
# ---- 队列满策略 ----
def set_full_policy(self, policy: QueueFullPolicy) -> None:
"""设置队列满时的处理策略。
Args:
policy: QueueFullPolicy 枚举值。
"""
self._full_policy = policy
@property
def full_policy(self) -> QueueFullPolicy:
"""当前队列满策略。"""
return self._full_policy
# ---- 入队 ----
async def put(
self,
item: dict[str, Any],
priority: Priority = Priority.NORMAL,
headers: dict[str, str] | None = None,
) -> str:
"""将请求放入队列。
Args:
item: 请求体(JSON 序列化的 dict)。
priority: 请求优先级,默认 NORMAL。
headers: 原始请求 headers。
Returns:
分配的唯一 request_id。
Raises:
QueueFullError: 队列满且策略为 REJECT。
"""
request_id = str(uuid.uuid4())
headers = headers or {}
queue_item = PriorityQueueItem(
sort_index=(int(priority), time.monotonic()),
priority=priority,
request_id=request_id,
payload=item,
enqueued_at=time.monotonic(),
headers=headers,
)
async with self._lock:
queue_size = len(self._heap)
if queue_size >= self.max_size:
if self._full_policy == QueueFullPolicy.REJECT:
raise QueueFullError(
f"队列已满 ({queue_size}/{self.max_size}),策略: reject"
)
elif self._full_policy == QueueFullPolicy.DROP_LOWEST:
# 丢弃 heap 中优先级最低(值最大)的元素
# heap 是最小堆,找最大值需要遍历
max_val_item = max(self._heap, key=lambda x: x.sort_index)
self._heap.remove(max_val_item)
heapq.heapify(self._heap)
self._total_dropped += 1
# PASSTHROUGH 策略:不插入队列,抛异常让调用方绕过排队
else:
raise QueueFullPassthrough(
f"队列已满 ({queue_size}/{self.max_size}),策略: passthrough"
)
heapq.heappush(self._heap, queue_item)
self._total_enqueued += 1
self._not_empty.set()
return request_id
# ---- 出队 ----
async def get(self, timeout: float = 1.0) -> PriorityQueueItem | None:
"""从队列取出下一个元素(阻塞、优先级排序)。
Args:
timeout: 阻塞等待的最大秒数,默认 1.0。
Returns:
优先级最高的队列元素;超时无元素时返回 None。
"""
deadline = time.monotonic() + timeout
while True:
async with self._lock:
if self._heap:
item = heapq.heappop(self._heap)
self._total_dequeued += 1
if not self._heap:
self._not_empty.clear()
return item
# 队列为空,等待新元素入队
remaining = deadline - time.monotonic()
if remaining <= 0:
return None
try:
await asyncio.wait_for(
self._not_empty.wait(),
timeout=remaining,
)
except asyncio.TimeoutError:
return None
# ---- 状态查询 ----
async def get_queue_size(self) -> int:
"""返回当前队列长度。"""
async with self._lock:
return len(self._heap)
async def get_stats(self) -> dict[str, Any]:
"""返回队列统计信息。"""
async with self._lock:
depth_by_priority: dict[str, int] = {}
for item in self._heap:
key = item.priority.name
depth_by_priority[key] = depth_by_priority.get(key, 0) + 1
return {
"max_size": self.max_size,
"current_size": len(self._heap),
"total_enqueued": self._total_enqueued,
"total_dequeued": self._total_dequeued,
"total_dropped": self._total_dropped,
"depth_by_priority": depth_by_priority,
"full_policy": self._full_policy.value,
"utilization": len(self._heap) / self.max_size if self.max_size > 0 else 0.0,
}
+41
View File
@@ -0,0 +1,41 @@
[project]
name = "nvidia_sidecar"
version = "0.1.0"
description = "NVIDIA Sidecar 限流代理 — 为 NVIDIA API 提供优先级排队 + 令牌桶限流"
readme = "README.md"
license = { text = "MIT" }
requires-python = ">=3.12"
dependencies = [
"fastapi>=0.115",
"uvicorn[standard]>=0.34",
"httpx>=0.28",
"PyYAML>=6.0",
"structlog>=24.4",
]
[project.optional-dependencies]
dev = [
"pytest>=8.3",
"pytest-asyncio>=0.24",
"httpx>=0.28",
"mypy>=1.14",
]
[project.scripts]
nvidia-sidecar = "nvidia_sidecar.server:main"
[build-system]
requires = ["setuptools>=75", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["."]
[tool.mypy]
python_version = "3.12"
strict = true
warn_return_any = true
warn_unused_configs = true
[[tool.mypy.overrides]]
module = "structlog.*"
ignore_missing_imports = true
+200
View File
@@ -0,0 +1,200 @@
"""
NVIDIA Sidecar 限流代理 — 令牌桶 + 网关识别模块 (§3.2)
从 BIZ-26 rate_limiter.py 提取核心限流逻辑,去除多线程调度器、缓存管理等。
保留:Priority, TokenBucket, is_nvidia_gateway, normalize_gateway_name。
"""
from __future__ import annotations
import time
import threading
from enum import IntEnum
from typing import Any
# ---------------------------------------------------------------------------
# 优先级枚举
# ---------------------------------------------------------------------------
class Priority(IntEnum):
"""请求优先级(数值越小优先级越高)。"""
URGENT = 1
HIGH = 2
NORMAL = 3
LOW = 4
# ---------------------------------------------------------------------------
# NVIDIA 网关别名集
# ---------------------------------------------------------------------------
NVIDIA_GATEWAY_ALIASES: set[str] = {
"nvidia",
"nvidia-gateway",
"nvidiavx",
"nvidiavx18088980513",
}
def is_nvidia_gateway(value: str | None) -> bool:
"""判断给定网关名/模型全路径是否属于 NVIDIA 网关。
Args:
value: 网关名(如 ``"nvidia"``)或模型全路径前缀
(如 ``"nvidia/deepseek-ai/deepseek-v4-pro"``)。
None 时直接返回 False。
Returns:
True 当 value 的 provider 部分匹配已知 NVIDIA 别名。
"""
if value is None:
return False
# 提取 provider 前缀:取 "/" 前第一个部分
provider = value.split("/", 1)[0].lower().strip()
return provider in NVIDIA_GATEWAY_ALIASES
def normalize_gateway_name(value: str | None) -> str | None:
"""规范化网关名:提取 provider 前缀并转为小写。
Args:
value: 网关名或模型全路径。None 时返回 None。
Returns:
provider 前缀的小写形式,或 None。
"""
if value is None:
return None
return value.split("/", 1)[0].lower().strip()
# ---------------------------------------------------------------------------
# 令牌桶(线程安全)
# ---------------------------------------------------------------------------
class TokenBucket:
"""线程安全的令牌桶实现。
支持固定速率令牌补充和消费,带有溢出保护和可选的阻塞等待。
"""
def __init__(self, rate: float = 40 / 60, capacity: int = 40) -> None:
"""初始化令牌桶。
Args:
rate: 令牌补充速率(令牌/秒)。默认 40/60 ≈ 0.667 token/s40 RPM)。
capacity: 桶最大容量(令牌数)。默认 40。
"""
self._rate: float = float(rate)
self._capacity: int = int(capacity)
self._tokens: float = float(capacity) # 启动时桶满
self._last_refill: float = time.monotonic()
self._lock: threading.Lock = threading.Lock()
# ---- 内部方法 ----
def _refill(self) -> None:
"""补充令牌(调用方需持有 _lock)。
根据距上次补充的时间差计算新增令牌数,不超过 capacity。
"""
now = time.monotonic()
elapsed = now - self._last_refill
if elapsed > 0 and self._rate > 0:
new_tokens = elapsed * self._rate
self._tokens = min(self._tokens + new_tokens, float(self._capacity))
self._last_refill = now
# ---- 公开方法 ----
def consume(self, tokens: int = 1) -> bool:
"""尝试立即消费令牌(非阻塞)。
Args:
tokens: 要消费的令牌数,默认 1。
Returns:
True 消费成功;False 令牌不足。
"""
if tokens <= 0:
return True
with self._lock:
self._refill()
if self._tokens >= tokens:
self._tokens -= tokens
return True
return False
def try_consume(self, tokens: int = 1, timeout: float = 2.0) -> bool:
"""尝试在指定时间内消费令牌(阻塞)。
Args:
tokens: 要消费的令牌数,默认 1。
timeout: 最大等待秒数,默认 2.0。
Returns:
True 在超时前成功消费;False 超时。
"""
if tokens <= 0:
return True
deadline = time.monotonic() + timeout
while True:
with self._lock:
self._refill()
if self._tokens >= tokens:
self._tokens -= tokens
return True
# 释放锁后计算剩余等待时间
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
# 等待到下一个令牌应该补充的时间点
sleep_time = min(remaining, max(0.05, 1.0 / self._rate) if self._rate > 0 else remaining)
time.sleep(sleep_time)
def wait_for_token(self, timeout: float | None = None) -> bool:
"""等待并尝试消费 1 个令牌。
Args:
timeout: 最大等待秒数;None 表示无限等待(不推荐)。
Returns:
True 成功消费;False 超时。
"""
return self.try_consume(tokens=1, timeout=timeout if timeout is not None else float("inf"))
def get_status(self) -> dict[str, Any]:
"""获取令牌桶当前状态。
Returns:
包含 tokens, capacity, rate_per_minute, utilization 的字典。
"""
with self._lock:
self._refill()
rate_per_minute = self._rate * 60.0
utilization = 0.0 if self._capacity == 0 else (
(self._capacity - self._tokens) / self._capacity
)
return {
"tokens": round(self._tokens, 2),
"capacity": self._capacity,
"rate_per_minute": round(rate_per_minute, 1),
"utilization": round(utilization, 4),
}
# ---- 属性 ----
@property
def rate(self) -> float:
"""当前令牌补充速率(令牌/秒)。"""
return self._rate
@property
def capacity(self) -> int:
"""桶容量。"""
return self._capacity
+701
View File
@@ -0,0 +1,701 @@
"""
NVIDIA Sidecar 限流代理 — FastAPI 代理主入口 (§3.4)
完整的 API 代理链路:
接收 → 网关识别 → [NVIDIA: 排队 → 令牌限流] → httpx 转发 → 返回
非 NVIDIA 请求直通上游,NVIDIA 请求经过四级优先级队列 + 令牌桶限流。
"""
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
import httpx
import structlog
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from nvidia_sidecar.config import load_config, SidecarConfig
from nvidia_sidecar.rate_limiter import (
Priority,
TokenBucket,
is_nvidia_gateway,
)
from nvidia_sidecar.priority_queue import (
PriorityRequestQueue,
QueueFullError,
QueueFullPassthrough,
QueueFullPolicy,
)
# ---------------------------------------------------------------------------
# 结构化日志
# ---------------------------------------------------------------------------
structlog.configure(
processors=[
structlog.stdlib.filter_by_level,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
structlog.dev.ConsoleRenderer(),
],
context_class=dict,
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.stdlib.BoundLogger,
cache_logger_on_first_use=True,
)
logger: structlog.stdlib.BoundLogger = structlog.get_logger("nvidia_sidecar")
# ---------------------------------------------------------------------------
# 全局状态(通过 lifespan 初始化,模块级引用方便路由访问)
# ---------------------------------------------------------------------------
_config: SidecarConfig
_http_client: httpx.AsyncClient
_priority_queue: PriorityRequestQueue
_token_bucket: TokenBucket
_pending_requests: dict[str, tuple[asyncio.Future[httpx.Response], float]]
"""request_id → (response future, enqueued_at) 的映射。"""
# 统计计数器
_stats: dict[str, int] = {
"total_requests": 0,
"nvidia_requests": 0,
"passthrough_requests": 0,
"ratelimited_requests": 0,
"queue_full_rejects": 0,
"upstream_errors": 0,
"start_time": 0,
}
# ---------------------------------------------------------------------------
# 工具函数
# ---------------------------------------------------------------------------
def _extract_model(body: dict[str, Any]) -> str | None:
"""从请求体中提取模型标识符(兼容 OpenAI Chat/Completions 格式)。
Args:
body: 已解析的 JSON 请求体。
Returns:
模型标识符字符串,或 None。
"""
if isinstance(body, dict):
return str(body.get("model", "")) or None
return None
def _resolve_priority(headers: dict[str, str]) -> Priority:
"""从请求 headers 解析优先级。
检查 ``X-Priority`` header,值为 ``urgent``/``high``/``normal``/``low``
不区分大小写。默认 NORMAL。
"""
raw = headers.get("x-priority", "").strip().lower()
mapping: dict[str, Priority] = {
"urgent": Priority.URGENT,
"high": Priority.HIGH,
"normal": Priority.NORMAL,
"low": Priority.LOW,
}
return mapping.get(raw, Priority.NORMAL)
# ---------------------------------------------------------------------------
# 上游转发
# ---------------------------------------------------------------------------
async def _forward_to_upstream(
method: str,
path: str,
body: bytes | None,
headers: dict[str, str],
stream: bool = False,
) -> httpx.Response:
"""将请求转发到 NVIDIA 上游 API。
Args:
method: HTTP 方法。
path: 请求路径(如 ``/v1/chat/completions``)。
body: 原始请求体 bytes。
headers: 要转发的请求 headers(会追加 Authorization)。
stream: 是否请求流式响应。
Returns:
httpx.Response 对象。
Raises:
httpx.HTTPError: HTTP 请求失败。
"""
upstream_url = _config.upstream_url.rstrip("/") + path
forward_headers: dict[str, str] = {
k: v for k, v in headers.items()
if k.lower() not in ("host", "content-length", "transfer-encoding")
}
if _config.upstream_api_key:
forward_headers["authorization"] = f"Bearer {_config.upstream_api_key}"
elif "authorization" not in {k.lower() for k in forward_headers}:
forward_headers["authorization"] = "Bearer nvidia"
try:
req = _http_client.build_request(
method=method,
url=upstream_url,
headers=forward_headers,
content=body,
timeout=_config.request_timeout,
)
response = await _http_client.send(req, stream=stream)
return response
except httpx.TimeoutException:
logger.warning("upstream_timeout", path=path, timeout=_config.request_timeout)
raise
except httpx.HTTPError as exc:
logger.error("upstream_error", path=path, error=str(exc))
raise
# ---------------------------------------------------------------------------
# worker 协程:消费优先级队列 + 令牌桶 + 转发
# ---------------------------------------------------------------------------
async def _worker_loop() -> None:
"""后台 worker:持续从优先级队列取请求 → 令牌限流 → 转发 → 设置 future 结果。"""
log = logger.bind(worker="main")
log.info("worker_started")
while True:
try:
queue_item = await _priority_queue.get(timeout=1.0)
if queue_item is None:
continue
request_id = queue_item.request_id
payload = queue_item.payload
headers = queue_item.headers
enqueued_at = queue_item.enqueued_at
# 查找对应的 pending future
pending_entry = _pending_requests.get(request_id)
if pending_entry is None:
log.warning("orphan_request", request_id=request_id)
continue
future, _ = pending_entry
# 低优先级令牌等待超时处理
if queue_item.priority == Priority.LOW:
# 放线程池执行阻塞的令牌桶调用
got_token = await asyncio.to_thread(
_token_bucket.try_consume,
tokens=1,
timeout=_config.low_priority_timeout,
)
if not got_token:
log.info("low_priority_timeout", request_id=request_id)
_stats["ratelimited_requests"] += 1
if not future.done():
future.set_exception(
_RateLimitedError(
f"低优先级请求令牌等待超时 ({_config.low_priority_timeout}s)"
)
)
_pending_requests.pop(request_id, None)
continue
else:
# 非低优先级:在 worker 内轮询等待令牌,避免重入队导致 future 悬挂
# (重入队会生成新 request_id,原 future 永不 resolve → 客户端永久 hang
got_token = await asyncio.to_thread(_token_bucket.consume, tokens=1)
if not got_token:
token_deadline = time.monotonic() + _config.request_timeout
while not got_token:
await asyncio.sleep(0.1)
got_token = await asyncio.to_thread(_token_bucket.consume, tokens=1)
if time.monotonic() > token_deadline:
break
if not got_token:
log.warning(
"token_wait_timeout",
request_id=request_id,
priority=queue_item.priority.name,
timeout=_config.request_timeout,
)
_stats["ratelimited_requests"] += 1
if not future.done():
future.set_exception(
_RateLimitedError(
f"令牌等待超时 ({_config.request_timeout:.0f}s)"
)
)
_pending_requests.pop(request_id, None)
continue
# 转发到上游
upstream_start = time.monotonic()
try:
path = headers.get("x-original-path", "/v1/chat/completions")
method = headers.get("x-original-method", "POST")
# 过滤内部 headers
clean_headers = {
k: v for k, v in headers.items()
if not k.startswith("x-original-") and not k.startswith("x-request-id")
}
resp = await _forward_to_upstream(
method=method,
path=path,
body=payload.get("_raw_body"),
headers=clean_headers,
stream=payload.get("stream", False),
)
upstream_latency = time.monotonic() - upstream_start
queue_latency = time.monotonic() - enqueued_at
total_latency = upstream_latency + queue_latency
log.info(
"request_completed",
request_id=request_id,
status=resp.status_code,
upstream_latency=round(upstream_latency, 3),
queue_latency=round(queue_latency, 3),
total_latency=round(total_latency, 3),
)
if not future.done():
future.set_result(resp)
except (httpx.HTTPError, OSError) as exc:
log.error("upstream_request_failed", request_id=request_id, error=str(exc))
_stats["upstream_errors"] += 1
if not future.done():
future.set_exception(exc)
_pending_requests.pop(request_id, None)
except asyncio.CancelledError:
log.info("worker_cancelled")
break
except Exception:
log.exception("worker_unexpected_error")
# ---------------------------------------------------------------------------
# PASSTHROUGH 直通路径(队列满 + PASSTHROUGH 策略)
# ---------------------------------------------------------------------------
async def _passthrough_with_rate_limit(
request: Request,
path: str,
body_bytes: bytes,
raw_headers: dict[str, str],
priority: Priority,
) -> Response:
"""队列满时的 PASSSTHROUGH 直通路径:仍受令牌桶限流,但不排队。
Args:
request: FastAPI Request。
path: 请求路径。
body_bytes: 原始请求体。
raw_headers: 请求 headers。
priority: 请求优先级。
Returns:
FastAPI Response。
"""
# 低优先级走令牌桶等待
if priority == Priority.LOW:
got_token = await asyncio.to_thread(
_token_bucket.try_consume,
tokens=1,
timeout=_config.low_priority_timeout,
)
if not got_token:
_stats["ratelimited_requests"] += 1
return JSONResponse(
status_code=429,
content={
"error": {
"message": f"令牌不足(队列满 + passthrough),超时 {_config.low_priority_timeout}s",
"type": "RateLimitedError",
}
},
)
else:
got_token = await asyncio.to_thread(_token_bucket.consume, tokens=1)
if not got_token:
# 非低优先级轮询等待
deadline = time.monotonic() + 30.0
while not got_token:
await asyncio.sleep(0.1)
got_token = await asyncio.to_thread(_token_bucket.consume, tokens=1)
if time.monotonic() > deadline:
_stats["ratelimited_requests"] += 1
return JSONResponse(
status_code=429,
content={
"error": {
"message": "令牌不足(队列满 + passthrough),等待超时 30s",
"type": "RateLimitedError",
}
},
)
# 拿到令牌,直接转发
try:
clean_headers = {k: v for k, v in raw_headers.items()}
resp = await _forward_to_upstream(
method=request.method,
path=path,
body=body_bytes if body_bytes else None,
headers=clean_headers,
stream=False,
)
return _build_response(resp)
except Exception as exc:
status, msg = _map_exception(exc)
logger.error("passthrough_error", path=path, error=str(exc))
return JSONResponse(
status_code=status,
content={"error": {"message": msg, "type": type(exc).__name__}},
)
# ---------------------------------------------------------------------------
# 自定义异常
# ---------------------------------------------------------------------------
class _RateLimitedError(Exception):
"""429 限流错误。"""
pass
# ---------------------------------------------------------------------------
# 异常处理矩阵 (§3.4)
# ---------------------------------------------------------------------------
_EXCEPTION_MATRIX: dict[type[Exception], tuple[int, str]] = {
_RateLimitedError: (429, "Too Many Requests — 令牌不足"),
QueueFullError: (503, "Service Unavailable — 队列已满"),
httpx.TimeoutException: (504, "Gateway Timeout — 上游超时"),
httpx.ConnectError: (502, "Bad Gateway — 上游连接失败"),
httpx.HTTPStatusError: (502, "Bad Gateway — 上游返回错误状态"),
}
def _map_exception(exc: Exception) -> tuple[int, str]:
"""将异常映射为 HTTP 状态码 + 错误信息。"""
for exc_type, (status, msg) in _EXCEPTION_MATRIX.items():
if isinstance(exc, exc_type):
return status, msg
return 500, f"Internal Server Error — {type(exc).__name__}"
# ---------------------------------------------------------------------------
# FastAPI 应用 + lifespan
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, Any]:
"""应用生命周期管理:初始化/清理全局资源。"""
global _config, _http_client, _priority_queue, _token_bucket, _pending_requests
# 启动
_config = load_config()
logging.getLogger().setLevel(_config.log_level.upper())
_http_client = httpx.AsyncClient(
timeout=httpx.Timeout(_config.request_timeout),
)
_priority_queue = PriorityRequestQueue(max_size=_config.queue_max_size)
_token_bucket = TokenBucket(
rate=_config.rate_rpm / 60.0,
capacity=_config.bucket_capacity,
)
_pending_requests = {}
_stats["start_time"] = int(time.time())
# 启动 worker 协程
worker_task = asyncio.create_task(_worker_loop())
logger.info(
"sidecar_started",
host=_config.listen_host,
port=_config.listen_port,
rate_rpm=_config.rate_rpm,
queue_max=_config.queue_max_size,
)
yield # app 运行中
# 关闭
worker_task.cancel()
try:
await worker_task
except asyncio.CancelledError:
pass
await _http_client.aclose()
logger.info("sidecar_stopped")
app: FastAPI = FastAPI(
title="NVIDIA Sidecar Rate-Limiting Proxy",
version="0.1.0",
lifespan=lifespan,
)
# ---------------------------------------------------------------------------
# 核心代理处理器
# ---------------------------------------------------------------------------
async def _handle_proxy_request(request: Request, path: str) -> Response:
"""统一的代理请求处理入口。
执行完整链路:
1. 解析请求体 → 提取 model
2. 网关识别 → 非 NVIDIA 直通
3. NVIDIA → 排队 + 令牌限流 + 转发
"""
_stats["total_requests"] += 1
# 解析请求
body_bytes: bytes = await request.body()
raw_headers: dict[str, str] = dict(request.headers)
# 尝试解析 JSON body
body_json: dict[str, Any] = {}
try:
if body_bytes:
body_json = __import__("json").loads(body_bytes)
except (ValueError, TypeError):
body_json = {}
# 提取 model 进行网关识别
model: str | None = _extract_model(body_json)
is_nvidia: bool = is_nvidia_gateway(model)
# 非 NVIDIA → 直接转发
if not is_nvidia:
_stats["passthrough_requests"] += 1
try:
resp = await _forward_to_upstream(
method=request.method,
path=path,
body=body_bytes if body_bytes else None,
headers=raw_headers,
stream=body_json.get("stream", False),
)
return _build_response(resp)
except Exception as exc:
status, msg = _map_exception(exc)
logger.error("passthrough_error", path=path, error=str(exc))
return JSONResponse(
status_code=status,
content={"error": {"message": msg, "type": type(exc).__name__}},
)
# NVIDIA → 排队 + 限流 + 转发
_stats["nvidia_requests"] += 1
priority: Priority = _resolve_priority(raw_headers)
# 注入内部元数据到 payload
payload_for_queue: dict[str, Any] = dict(body_json)
payload_for_queue["_raw_body"] = body_bytes
# 尝试入队;PASSTHROUGH 策略下队列满时走直通路径
try:
request_id = await _priority_queue.put(
item=payload_for_queue,
priority=priority,
headers={
**raw_headers,
"x-original-path": path,
"x-original-method": request.method,
},
)
except QueueFullError:
_stats["queue_full_rejects"] += 1
return JSONResponse(
status_code=503,
content={
"error": {
"message": "队列已满,当前策略: reject",
"type": "QueueFullError",
}
},
)
except QueueFullPassthrough:
# 队列满 + PASSTHROUGH:绕过排队,尝试令牌桶后直接转发
_stats["passthrough_requests"] += 1
logger.info("queue_full_passthrough", path=path)
return await _passthrough_with_rate_limit(request, path, body_bytes, raw_headers, priority)
# 创建 future 并注册到 pending
loop = asyncio.get_running_loop()
future: asyncio.Future[httpx.Response] = loop.create_future()
_pending_requests[request_id] = (future, time.monotonic())
try:
# 等待 worker 完成处理
resp = await future
return _build_response(resp)
except _RateLimitedError as exc:
return JSONResponse(
status_code=429,
content={
"error": {
"message": str(exc),
"type": "RateLimitedError",
}
},
)
except Exception as exc:
status, msg = _map_exception(exc)
logger.error("proxy_error", path=path, request_id=request_id, error=str(exc))
return JSONResponse(
status_code=status,
content={"error": {"message": msg, "type": type(exc).__name__}},
)
def _build_response(resp: httpx.Response) -> Response:
"""将 httpx.Response 转换为 FastAPI Response。
支持 JSON 和流式 (SSE) 两种响应类型。
"""
content_type = resp.headers.get("content-type", "")
# 流式响应 (SSE)
if "text/event-stream" in content_type or "stream" in content_type:
return StreamingResponse(
content=resp.aiter_bytes(),
status_code=resp.status_code,
headers={
k: v for k, v in resp.headers.items()
if k.lower() not in ("content-encoding", "transfer-encoding")
},
media_type="text/event-stream",
)
# 普通 JSON 响应
return Response(
content=resp.content,
status_code=resp.status_code,
headers={
k: v for k, v in resp.headers.items()
if k.lower() not in ("content-encoding", "transfer-encoding")
},
media_type=content_type or "application/json",
)
# ---------------------------------------------------------------------------
# 路由
# ---------------------------------------------------------------------------
@app.get("/health")
async def health() -> dict[str, Any]:
"""健康检查端点。"""
queue_stats = await _priority_queue.get_stats()
bucket_status = _token_bucket.get_status()
return {
"status": "ok",
"version": "0.1.0",
"uptime_seconds": int(time.time() - _stats["start_time"]) if _stats["start_time"] else 0,
"queue": queue_stats,
"token_bucket": bucket_status,
}
@app.get("/metrics")
async def metrics() -> dict[str, Any]:
"""Prometheus 格式 metrics 端点。"""
queue_stats = await _priority_queue.get_stats()
bucket_status = _token_bucket.get_status()
return {
"requests": {
"total": _stats["total_requests"],
"nvidia": _stats["nvidia_requests"],
"passthrough": _stats["passthrough_requests"],
"ratelimited": _stats["ratelimited_requests"],
},
"errors": {
"queue_full_rejects": _stats["queue_full_rejects"],
"upstream_errors": _stats["upstream_errors"],
},
"queue": queue_stats,
"token_bucket": bucket_status,
"uptime_seconds": int(time.time() - _stats["start_time"]) if _stats["start_time"] else 0,
}
# ---- OpenAI 兼容端点 ----
@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Response:
"""OpenAI Chat Completions API 代理(含流式支持)。"""
return await _handle_proxy_request(request, "/v1/chat/completions")
@app.post("/v1/completions")
async def completions(request: Request) -> Response:
"""OpenAI Completions API 代理(legacy)。"""
return await _handle_proxy_request(request, "/v1/completions")
@app.post("/v1/embeddings")
async def embeddings(request: Request) -> Response:
"""OpenAI Embeddings API 代理。"""
return await _handle_proxy_request(request, "/v1/embeddings")
@app.get("/v1/models")
@app.get("/v1/models/{model_id:path}")
async def list_models(request: Request, model_id: str | None = None) -> Response:
"""OpenAI Models API 代理。"""
path = f"/v1/models/{model_id}" if model_id else "/v1/models"
return await _handle_proxy_request(request, path)
# ---- 通用代理(catch-all 用于非标准 NVIDIA 端点) ----
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
async def catch_all(request: Request, path: str) -> Response:
"""通用代理端点:转发任何未匹配的路径到上游。"""
target_path = f"/{path}" if not path.startswith("/") else path
return await _handle_proxy_request(request, target_path)
# ---------------------------------------------------------------------------
# 入口
# ---------------------------------------------------------------------------
def main() -> None:
"""开发/调试入口。"""
import uvicorn
cfg: SidecarConfig = load_config()
uvicorn.run(
"nvidia_sidecar.server:app",
host=cfg.listen_host,
port=cfg.listen_port,
log_level=cfg.log_level.lower(),
)
if __name__ == "__main__":
main()