diff --git a/services/nvidia_sidecar/README.md b/services/nvidia_sidecar/README.md new file mode 100644 index 0000000..5b662fe --- /dev/null +++ b/services/nvidia_sidecar/README.md @@ -0,0 +1,63 @@ +# NVIDIA Sidecar 限流代理 + +为 NVIDIA API 提供**优先级排队 + 令牌桶限流**的透明代理层。 + +## 快速启动 + +```bash +pip install . +nvidia-sidecar +``` + +监听 `127.0.0.1:9190`,代理到 NVIDIA API。 + +## 环境变量 + +| 变量 | 默认值 | 说明 | +|------|--------|------| +| `SIDECAR_HOST` | `127.0.0.1` | 监听地址 | +| `SIDECAR_PORT` | `9190` | 监听端口 | +| `SIDECAR_METRICS_PORT` | `9191` | Metrics 端口 | +| `SIDECAR_UPSTREAM` | `https://integrate.api.nvidia.com/v1` | 上游 API 地址 | +| `SIDECAR_API_KEY` | — | NVIDIA API Key(必填) | +| `SIDECAR_RATE_RPM` | `40` | 每分钟请求数限制 | +| `SIDECAR_BUCKET_CAPACITY` | `40` | 令牌桶容量 | +| `SIDECAR_TIMEOUT` | `6000` | 上游请求超时(秒) | +| `SIDECAR_QUEUE_MAX` | `500` | 队列最大长度 | +| `SIDECAR_LOW_TIMEOUT` | `2.0` | 低优先级令牌等待超时(秒) | +| `SIDECAR_FALLBACK_PASSTHROUGH` | `true` | 队列满时是否直通上游 | +| `SIDECAR_LOG_LEVEL` | `INFO` | 日志级别 | + +## YAML 配置 + +```yaml +listen_port: 9292 +rate_rpm: 60 +upstream_api_key: "nvapi-xxx" +``` + +```bash +nvidia-sidecar --config /etc/nvidia-sidecar.yaml +``` + +## API 端点 + +| 路径 | 方法 | 说明 | +|------|------|------| +| `/v1/chat/completions` | POST | OpenAI Chat Completions 代理 | +| `/v1/completions` | POST | OpenAI Completions 代理(legacy) | +| `/v1/embeddings` | POST | OpenAI Embeddings 代理 | +| `/v1/models` | GET | 模型列表代理 | +| `/health` | GET | 健康检查 | +| `/metrics` | GET | 指标查询 | + +## 架构 + +``` +请求 → 网关识别 → [NVIDIA: 优先级排队 → 令牌桶限流] → httpx → NVIDIA API + → [非 NVIDIA: 直通] → httpx → 上游 +``` + +- **四级优先级**: URGENT > HIGH > NORMAL > LOW(通过 `X-Priority` header 指定) +- **队列满策略**: PASSTHROUGH(直通)/ REJECT(503)/ DROP_LOWEST(丢弃最低优先级) +- **令牌桶**: 40 RPM,线程安全,支持阻塞/非阻塞消费 \ No newline at end of file diff --git a/services/nvidia_sidecar/__init__.py b/services/nvidia_sidecar/__init__.py new file mode 100644 index 0000000..c073124 --- /dev/null +++ b/services/nvidia_sidecar/__init__.py @@ -0,0 +1,41 @@ +""" +NVIDIA Sidecar 限流代理 — 核心代理模块。 + +为 OpenAI Chat Completions 兼容 API 提供四层防护: + 1. 请求接收(FastAPI) + 2. 网关识别 → 非 NVIDIA 直通 + 3. 优先级排队 → 令牌桶限流 + 4. httpx 异步转发到 NVIDIA 上游 +""" + +from __future__ import annotations + +from nvidia_sidecar.config import SidecarConfig, load_config +from nvidia_sidecar.rate_limiter import ( + Priority, + TokenBucket, + is_nvidia_gateway, + normalize_gateway_name, +) +from nvidia_sidecar.priority_queue import ( + PriorityQueueItem, + PriorityRequestQueue, + QueueFullError, + QueueFullPassthrough, + QueueFullPolicy, +) + +__version__ = "0.1.0" +__all__ = [ + "SidecarConfig", + "load_config", + "Priority", + "TokenBucket", + "is_nvidia_gateway", + "normalize_gateway_name", + "PriorityQueueItem", + "PriorityRequestQueue", + "QueueFullError", + "QueueFullPassthrough", + "QueueFullPolicy", +] \ No newline at end of file diff --git a/services/nvidia_sidecar/config.py b/services/nvidia_sidecar/config.py new file mode 100644 index 0000000..3176fb8 --- /dev/null +++ b/services/nvidia_sidecar/config.py @@ -0,0 +1,216 @@ +""" +NVIDIA Sidecar 限流代理 — 配置管理模块 (§3.1) + +集中管理 Sidecar 运行参数,支持环境变量覆盖和 YAML 配置文件。 +""" + +from __future__ import annotations + +import os +import warnings +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +@dataclass +class SidecarConfig: + """Sidecar 运行配置数据类。 + + 所有字段可通过环境变量覆盖,优先级:环境变量 > YAML 配置文件 > 默认值。 + """ + + # ---- 网络 ---- + listen_host: str = field( + default="127.0.0.1", + metadata={"env": "SIDECAR_HOST"}, + ) + listen_port: int = field( + default=9190, + metadata={"env": "SIDECAR_PORT"}, + ) + metrics_port: int = field( + default=9191, + metadata={"env": "SIDECAR_METRICS_PORT"}, + ) + + # ---- 上游 ---- + upstream_url: str = field( + default="https://integrate.api.nvidia.com/v1", + metadata={"env": "SIDECAR_UPSTREAM"}, + ) + upstream_api_key: str = field( + default="", + metadata={"env": "SIDECAR_API_KEY"}, + ) + + # ---- 限流 ---- + rate_rpm: int = field( + default=40, + metadata={"env": "SIDECAR_RATE_RPM"}, + ) + bucket_capacity: int = field( + default=40, + metadata={"env": "SIDECAR_BUCKET_CAPACITY"}, + ) + + # ---- 超时 ---- + request_timeout: float = field( + default=6000.0, + metadata={"env": "SIDECAR_TIMEOUT"}, + ) + + # ---- 队列 ---- + queue_max_size: int = field( + default=500, + metadata={"env": "SIDECAR_QUEUE_MAX"}, + ) + low_priority_timeout: float = field( + default=2.0, + metadata={"env": "SIDECAR_LOW_TIMEOUT"}, + ) + + # ---- 降级 ---- + fallback_enabled_passthrough: bool = field( + default=True, + metadata={"env": "SIDECAR_FALLBACK_PASSTHROUGH"}, + ) + + # ---- 日志 ---- + log_level: str = field( + default="INFO", + metadata={"env": "SIDECAR_LOG_LEVEL"}, + ) + + +def _apply_env_overrides(config: SidecarConfig) -> SidecarConfig: + """用环境变量覆盖配置字段。 + + 遍历 SidecarConfig 的 dataclass fields,对每个声明了 ``metadata={"env": ...}`` + 的字段检查环境变量是否存在,存在则用对应类型转换后覆盖。 + """ + import dataclasses as _dc + + # 使用 typing.get_type_hints 解析 from __future__ import annotations + # 引入的字符串化类型注解 (PEP 563) + try: + resolved_types = __import__("typing").get_type_hints(type(config)) + except Exception: + resolved_types = {} + + for fld in _dc.fields(config): + env_key: str | None = fld.metadata.get("env") + if env_key is None: + continue + env_val = os.environ.get(env_key) + if env_val is None: + continue + + target_type = resolved_types.get(fld.name, fld.type) + target_type_name: str = getattr(target_type, "__name__", str(target_type)) + try: + if target_type is bool or target_type == "bool": + parsed: bool = env_val.strip().lower() in ("true", "1", "yes", "on") + setattr(config, fld.name, parsed) + elif target_type is int or target_type == "int": + setattr(config, fld.name, int(env_val)) + elif target_type is float or target_type == "float": + setattr(config, fld.name, float(env_val)) + else: + setattr(config, fld.name, env_val) + except (ValueError, TypeError) as exc: + warnings.warn( + f"无法将环境变量 {env_key}={env_val!r} 转换为 {target_type_name}: {exc}" + ) + + return config + + +def _validate_config(config: SidecarConfig) -> list[str]: + """验证配置合理性,返回警告/问题列表。""" + issues: list[str] = [] + + # 端口冲突检查 + if config.listen_port == config.metrics_port: + issues.append( + f"listen_port ({config.listen_port}) 与 metrics_port ({config.metrics_port}) 相同" + ) + + # rate_rpm 边界检查 + if config.rate_rpm <= 0: + issues.append( + f"rate_rpm ({config.rate_rpm}) 无效,回退到默认值 40" + ) + config.rate_rpm = 40 + + # queue_max_size 合理性 + if config.queue_max_size <= 0: + issues.append( + f"queue_max_size ({config.queue_max_size}) 无效,回退到默认值 500" + ) + config.queue_max_size = 500 + + # request_timeout 合理性 + if config.request_timeout <= 0: + issues.append( + f"request_timeout ({config.request_timeout}) 无效,回退到默认值 6000" + ) + config.request_timeout = 6000.0 + + return issues + + +def load_config(path: str | None = None) -> SidecarConfig: + """加载 Sidecar 配置。 + + 加载顺序(后者覆盖前者): + 1. 默认值(SidecarConfig dataclass defaults) + 2. YAML 配置文件(如果 path 提供) + 3. 环境变量覆盖 + + Args: + path: 可选 YAML 配置文件路径。为 None 时只使用默认值 + 环境变量。 + + Returns: + 经过验证的 SidecarConfig 实例。 + + Raises: + FileNotFoundError: path 指定的文件不存在。 + yaml.YAMLError: YAML 解析失败。 + """ + config = SidecarConfig() + + if path is not None: + import yaml + + cfg_path = Path(path) + if not cfg_path.is_file(): + raise FileNotFoundError(f"配置文件不存在: {cfg_path}") + + try: + raw: dict[str, Any] = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {} + except yaml.YAMLError as exc: + raise yaml.YAMLError(f"YAML 解析失败 ({cfg_path}): {exc}") from exc + + # 覆盖已声明的字段 + for fld_name in ( + "listen_host", "listen_port", "metrics_port", + "upstream_url", "upstream_api_key", + "rate_rpm", "bucket_capacity", + "request_timeout", + "queue_max_size", "low_priority_timeout", + "fallback_enabled_passthrough", + "log_level", + ): + if fld_name in raw: + setattr(config, fld_name, raw[fld_name]) + + # 环境变量覆盖(最高优先级) + config = _apply_env_overrides(config) + + # 验证 + issues = _validate_config(config) + for issue in issues: + warnings.warn(issue) + + return config \ No newline at end of file diff --git a/services/nvidia_sidecar/priority_queue.py b/services/nvidia_sidecar/priority_queue.py new file mode 100644 index 0000000..3db1d05 --- /dev/null +++ b/services/nvidia_sidecar/priority_queue.py @@ -0,0 +1,226 @@ +""" +NVIDIA Sidecar 限流代理 — 四级优先级请求队列模块 (§3.3) + +管理待处理的 NVIDIA API 请求,按优先级 + FIFO 出队。 +支持三种队列满策略:PASSTHROUGH / REJECT / DROP_LOWEST。 +""" + +from __future__ import annotations + +import asyncio +import heapq +import time +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from nvidia_sidecar.rate_limiter import Priority + + +# --------------------------------------------------------------------------- +# 队列满策略 +# --------------------------------------------------------------------------- + +class QueueFullPolicy(str, Enum): + """队列满时的处理策略。""" + PASSTHROUGH = "passthrough" # 直通上游,绕过排队(fail-open 子策略) + REJECT = "reject" # 返回 503 Service Unavailable + DROP_LOWEST = "drop_lowest" # 丢弃队列中最低优先级元素,插入新请求 + + +# --------------------------------------------------------------------------- +# 队列元素 +# --------------------------------------------------------------------------- + +@dataclass(order=True) +class PriorityQueueItem: + """优先级队列元素。 + + ``sort_index`` 由 ``(priority, timestamp)`` 组成, + Python 的 ``__lt__`` 按字段顺序比较:先比 priority,再比 timestamp。 + 数值越小越优先(URGENT=1 优于 HIGH=2)。 + """ + sort_index: tuple[int, float] = field(compare=True) + priority: Priority = field(compare=False) + request_id: str = field(compare=False) + payload: dict[str, Any] = field(compare=False) + enqueued_at: float = field(compare=False) + headers: dict[str, str] = field(default_factory=dict, compare=False) + + +# --------------------------------------------------------------------------- +# 优先级请求队列 +# --------------------------------------------------------------------------- + +class QueueFullError(Exception): + """队列已满且策略为 REJECT 时抛出。""" + pass + + +class QueueFullPassthrough(Exception): + """队列已满且策略为 PASSTHROUGH 时抛出,由调用方绕过队列直通上游。""" + pass + + +class PriorityRequestQueue: + """异步线程安全的四级优先级请求队列。 + + 内部使用 ``asyncio.Lock`` 保护并发操作, + 基于 ``heapq`` + ``asyncio.Event`` 实现阻塞出队。 + """ + + def __init__(self, max_size: int = 500) -> None: + """初始化优先级队列。 + + Args: + max_size: 队列最大容量。 + + Raises: + ValueError: max_size <= 0。 + """ + if max_size <= 0: + raise ValueError(f"max_size 必须为正整数,当前值: {max_size}") + self.max_size: int = max_size + self._heap: list[PriorityQueueItem] = [] + self._lock: asyncio.Lock = asyncio.Lock() + self._not_empty: asyncio.Event = asyncio.Event() + self._full_policy: QueueFullPolicy = QueueFullPolicy.PASSTHROUGH + + # 统计 + self._total_enqueued: int = 0 + self._total_dequeued: int = 0 + self._total_dropped: int = 0 + + # ---- 队列满策略 ---- + + def set_full_policy(self, policy: QueueFullPolicy) -> None: + """设置队列满时的处理策略。 + + Args: + policy: QueueFullPolicy 枚举值。 + """ + self._full_policy = policy + + @property + def full_policy(self) -> QueueFullPolicy: + """当前队列满策略。""" + return self._full_policy + + # ---- 入队 ---- + + async def put( + self, + item: dict[str, Any], + priority: Priority = Priority.NORMAL, + headers: dict[str, str] | None = None, + ) -> str: + """将请求放入队列。 + + Args: + item: 请求体(JSON 序列化的 dict)。 + priority: 请求优先级,默认 NORMAL。 + headers: 原始请求 headers。 + + Returns: + 分配的唯一 request_id。 + + Raises: + QueueFullError: 队列满且策略为 REJECT。 + """ + request_id = str(uuid.uuid4()) + headers = headers or {} + + queue_item = PriorityQueueItem( + sort_index=(int(priority), time.monotonic()), + priority=priority, + request_id=request_id, + payload=item, + enqueued_at=time.monotonic(), + headers=headers, + ) + + async with self._lock: + queue_size = len(self._heap) + if queue_size >= self.max_size: + if self._full_policy == QueueFullPolicy.REJECT: + raise QueueFullError( + f"队列已满 ({queue_size}/{self.max_size}),策略: reject" + ) + elif self._full_policy == QueueFullPolicy.DROP_LOWEST: + # 丢弃 heap 中优先级最低(值最大)的元素 + # heap 是最小堆,找最大值需要遍历 + max_val_item = max(self._heap, key=lambda x: x.sort_index) + self._heap.remove(max_val_item) + heapq.heapify(self._heap) + self._total_dropped += 1 + # PASSTHROUGH 策略:不插入队列,抛异常让调用方绕过排队 + else: + raise QueueFullPassthrough( + f"队列已满 ({queue_size}/{self.max_size}),策略: passthrough" + ) + + heapq.heappush(self._heap, queue_item) + self._total_enqueued += 1 + + self._not_empty.set() + return request_id + + # ---- 出队 ---- + + async def get(self, timeout: float = 1.0) -> PriorityQueueItem | None: + """从队列取出下一个元素(阻塞、优先级排序)。 + + Args: + timeout: 阻塞等待的最大秒数,默认 1.0。 + + Returns: + 优先级最高的队列元素;超时无元素时返回 None。 + """ + deadline = time.monotonic() + timeout + while True: + async with self._lock: + if self._heap: + item = heapq.heappop(self._heap) + self._total_dequeued += 1 + if not self._heap: + self._not_empty.clear() + return item + + # 队列为空,等待新元素入队 + remaining = deadline - time.monotonic() + if remaining <= 0: + return None + try: + await asyncio.wait_for( + self._not_empty.wait(), + timeout=remaining, + ) + except asyncio.TimeoutError: + return None + + # ---- 状态查询 ---- + + async def get_queue_size(self) -> int: + """返回当前队列长度。""" + async with self._lock: + return len(self._heap) + + async def get_stats(self) -> dict[str, Any]: + """返回队列统计信息。""" + async with self._lock: + depth_by_priority: dict[str, int] = {} + for item in self._heap: + key = item.priority.name + depth_by_priority[key] = depth_by_priority.get(key, 0) + 1 + + return { + "max_size": self.max_size, + "current_size": len(self._heap), + "total_enqueued": self._total_enqueued, + "total_dequeued": self._total_dequeued, + "total_dropped": self._total_dropped, + "depth_by_priority": depth_by_priority, + "full_policy": self._full_policy.value, + "utilization": len(self._heap) / self.max_size if self.max_size > 0 else 0.0, + } \ No newline at end of file diff --git a/services/nvidia_sidecar/pyproject.toml b/services/nvidia_sidecar/pyproject.toml new file mode 100644 index 0000000..c183a84 --- /dev/null +++ b/services/nvidia_sidecar/pyproject.toml @@ -0,0 +1,41 @@ +[project] +name = "nvidia_sidecar" +version = "0.1.0" +description = "NVIDIA Sidecar 限流代理 — 为 NVIDIA API 提供优先级排队 + 令牌桶限流" +readme = "README.md" +license = { text = "MIT" } +requires-python = ">=3.12" +dependencies = [ + "fastapi>=0.115", + "uvicorn[standard]>=0.34", + "httpx>=0.28", + "PyYAML>=6.0", + "structlog>=24.4", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.3", + "pytest-asyncio>=0.24", + "httpx>=0.28", + "mypy>=1.14", +] + +[project.scripts] +nvidia-sidecar = "nvidia_sidecar.server:main" + +[build-system] +requires = ["setuptools>=75", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["."] + +[tool.mypy] +python_version = "3.12" +strict = true +warn_return_any = true +warn_unused_configs = true +[[tool.mypy.overrides]] +module = "structlog.*" +ignore_missing_imports = true \ No newline at end of file diff --git a/services/nvidia_sidecar/rate_limiter.py b/services/nvidia_sidecar/rate_limiter.py new file mode 100644 index 0000000..3e2eaa3 --- /dev/null +++ b/services/nvidia_sidecar/rate_limiter.py @@ -0,0 +1,200 @@ +""" +NVIDIA Sidecar 限流代理 — 令牌桶 + 网关识别模块 (§3.2) + +从 BIZ-26 rate_limiter.py 提取核心限流逻辑,去除多线程调度器、缓存管理等。 +保留:Priority, TokenBucket, is_nvidia_gateway, normalize_gateway_name。 +""" + +from __future__ import annotations + +import time +import threading +from enum import IntEnum +from typing import Any + + +# --------------------------------------------------------------------------- +# 优先级枚举 +# --------------------------------------------------------------------------- + +class Priority(IntEnum): + """请求优先级(数值越小优先级越高)。""" + URGENT = 1 + HIGH = 2 + NORMAL = 3 + LOW = 4 + + +# --------------------------------------------------------------------------- +# NVIDIA 网关别名集 +# --------------------------------------------------------------------------- + +NVIDIA_GATEWAY_ALIASES: set[str] = { + "nvidia", + "nvidia-gateway", + "nvidiavx", + "nvidiavx18088980513", +} + + +def is_nvidia_gateway(value: str | None) -> bool: + """判断给定网关名/模型全路径是否属于 NVIDIA 网关。 + + Args: + value: 网关名(如 ``"nvidia"``)或模型全路径前缀 + (如 ``"nvidia/deepseek-ai/deepseek-v4-pro"``)。 + None 时直接返回 False。 + + Returns: + True 当 value 的 provider 部分匹配已知 NVIDIA 别名。 + """ + if value is None: + return False + + # 提取 provider 前缀:取 "/" 前第一个部分 + provider = value.split("/", 1)[0].lower().strip() + return provider in NVIDIA_GATEWAY_ALIASES + + +def normalize_gateway_name(value: str | None) -> str | None: + """规范化网关名:提取 provider 前缀并转为小写。 + + Args: + value: 网关名或模型全路径。None 时返回 None。 + + Returns: + provider 前缀的小写形式,或 None。 + """ + if value is None: + return None + return value.split("/", 1)[0].lower().strip() + + +# --------------------------------------------------------------------------- +# 令牌桶(线程安全) +# --------------------------------------------------------------------------- + +class TokenBucket: + """线程安全的令牌桶实现。 + + 支持固定速率令牌补充和消费,带有溢出保护和可选的阻塞等待。 + """ + + def __init__(self, rate: float = 40 / 60, capacity: int = 40) -> None: + """初始化令牌桶。 + + Args: + rate: 令牌补充速率(令牌/秒)。默认 40/60 ≈ 0.667 token/s(40 RPM)。 + capacity: 桶最大容量(令牌数)。默认 40。 + """ + self._rate: float = float(rate) + self._capacity: int = int(capacity) + self._tokens: float = float(capacity) # 启动时桶满 + self._last_refill: float = time.monotonic() + self._lock: threading.Lock = threading.Lock() + + # ---- 内部方法 ---- + + def _refill(self) -> None: + """补充令牌(调用方需持有 _lock)。 + + 根据距上次补充的时间差计算新增令牌数,不超过 capacity。 + """ + now = time.monotonic() + elapsed = now - self._last_refill + if elapsed > 0 and self._rate > 0: + new_tokens = elapsed * self._rate + self._tokens = min(self._tokens + new_tokens, float(self._capacity)) + self._last_refill = now + + # ---- 公开方法 ---- + + def consume(self, tokens: int = 1) -> bool: + """尝试立即消费令牌(非阻塞)。 + + Args: + tokens: 要消费的令牌数,默认 1。 + + Returns: + True 消费成功;False 令牌不足。 + """ + if tokens <= 0: + return True + + with self._lock: + self._refill() + if self._tokens >= tokens: + self._tokens -= tokens + return True + return False + + def try_consume(self, tokens: int = 1, timeout: float = 2.0) -> bool: + """尝试在指定时间内消费令牌(阻塞)。 + + Args: + tokens: 要消费的令牌数,默认 1。 + timeout: 最大等待秒数,默认 2.0。 + + Returns: + True 在超时前成功消费;False 超时。 + """ + if tokens <= 0: + return True + + deadline = time.monotonic() + timeout + while True: + with self._lock: + self._refill() + if self._tokens >= tokens: + self._tokens -= tokens + return True + + # 释放锁后计算剩余等待时间 + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + # 等待到下一个令牌应该补充的时间点 + sleep_time = min(remaining, max(0.05, 1.0 / self._rate) if self._rate > 0 else remaining) + time.sleep(sleep_time) + + def wait_for_token(self, timeout: float | None = None) -> bool: + """等待并尝试消费 1 个令牌。 + + Args: + timeout: 最大等待秒数;None 表示无限等待(不推荐)。 + + Returns: + True 成功消费;False 超时。 + """ + return self.try_consume(tokens=1, timeout=timeout if timeout is not None else float("inf")) + + def get_status(self) -> dict[str, Any]: + """获取令牌桶当前状态。 + + Returns: + 包含 tokens, capacity, rate_per_minute, utilization 的字典。 + """ + with self._lock: + self._refill() + rate_per_minute = self._rate * 60.0 + utilization = 0.0 if self._capacity == 0 else ( + (self._capacity - self._tokens) / self._capacity + ) + return { + "tokens": round(self._tokens, 2), + "capacity": self._capacity, + "rate_per_minute": round(rate_per_minute, 1), + "utilization": round(utilization, 4), + } + + # ---- 属性 ---- + + @property + def rate(self) -> float: + """当前令牌补充速率(令牌/秒)。""" + return self._rate + + @property + def capacity(self) -> int: + """桶容量。""" + return self._capacity \ No newline at end of file diff --git a/services/nvidia_sidecar/server.py b/services/nvidia_sidecar/server.py new file mode 100644 index 0000000..d386c34 --- /dev/null +++ b/services/nvidia_sidecar/server.py @@ -0,0 +1,701 @@ +""" +NVIDIA Sidecar 限流代理 — FastAPI 代理主入口 (§3.4) + +完整的 API 代理链路: + 接收 → 网关识别 → [NVIDIA: 排队 → 令牌限流] → httpx 转发 → 返回 + +非 NVIDIA 请求直通上游,NVIDIA 请求经过四级优先级队列 + 令牌桶限流。 +""" + +from __future__ import annotations + +import asyncio +import logging +import time +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Any + +import httpx +import structlog +from fastapi import FastAPI, Request, Response +from fastapi.responses import JSONResponse, StreamingResponse + +from nvidia_sidecar.config import load_config, SidecarConfig +from nvidia_sidecar.rate_limiter import ( + Priority, + TokenBucket, + is_nvidia_gateway, +) +from nvidia_sidecar.priority_queue import ( + PriorityRequestQueue, + QueueFullError, + QueueFullPassthrough, + QueueFullPolicy, +) + +# --------------------------------------------------------------------------- +# 结构化日志 +# --------------------------------------------------------------------------- + +structlog.configure( + processors=[ + structlog.stdlib.filter_by_level, + structlog.stdlib.add_logger_name, + structlog.stdlib.add_log_level, + structlog.stdlib.PositionalArgumentsFormatter(), + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.StackInfoRenderer(), + structlog.processors.format_exc_info, + structlog.processors.UnicodeDecoder(), + structlog.dev.ConsoleRenderer(), + ], + context_class=dict, + logger_factory=structlog.stdlib.LoggerFactory(), + wrapper_class=structlog.stdlib.BoundLogger, + cache_logger_on_first_use=True, +) +logger: structlog.stdlib.BoundLogger = structlog.get_logger("nvidia_sidecar") + + +# --------------------------------------------------------------------------- +# 全局状态(通过 lifespan 初始化,模块级引用方便路由访问) +# --------------------------------------------------------------------------- + +_config: SidecarConfig +_http_client: httpx.AsyncClient +_priority_queue: PriorityRequestQueue +_token_bucket: TokenBucket +_pending_requests: dict[str, tuple[asyncio.Future[httpx.Response], float]] +"""request_id → (response future, enqueued_at) 的映射。""" + +# 统计计数器 +_stats: dict[str, int] = { + "total_requests": 0, + "nvidia_requests": 0, + "passthrough_requests": 0, + "ratelimited_requests": 0, + "queue_full_rejects": 0, + "upstream_errors": 0, + "start_time": 0, +} + + +# --------------------------------------------------------------------------- +# 工具函数 +# --------------------------------------------------------------------------- + +def _extract_model(body: dict[str, Any]) -> str | None: + """从请求体中提取模型标识符(兼容 OpenAI Chat/Completions 格式)。 + + Args: + body: 已解析的 JSON 请求体。 + + Returns: + 模型标识符字符串,或 None。 + """ + if isinstance(body, dict): + return str(body.get("model", "")) or None + return None + + +def _resolve_priority(headers: dict[str, str]) -> Priority: + """从请求 headers 解析优先级。 + + 检查 ``X-Priority`` header,值为 ``urgent``/``high``/``normal``/``low``, + 不区分大小写。默认 NORMAL。 + """ + raw = headers.get("x-priority", "").strip().lower() + mapping: dict[str, Priority] = { + "urgent": Priority.URGENT, + "high": Priority.HIGH, + "normal": Priority.NORMAL, + "low": Priority.LOW, + } + return mapping.get(raw, Priority.NORMAL) + + +# --------------------------------------------------------------------------- +# 上游转发 +# --------------------------------------------------------------------------- + +async def _forward_to_upstream( + method: str, + path: str, + body: bytes | None, + headers: dict[str, str], + stream: bool = False, +) -> httpx.Response: + """将请求转发到 NVIDIA 上游 API。 + + Args: + method: HTTP 方法。 + path: 请求路径(如 ``/v1/chat/completions``)。 + body: 原始请求体 bytes。 + headers: 要转发的请求 headers(会追加 Authorization)。 + stream: 是否请求流式响应。 + + Returns: + httpx.Response 对象。 + + Raises: + httpx.HTTPError: HTTP 请求失败。 + """ + upstream_url = _config.upstream_url.rstrip("/") + path + forward_headers: dict[str, str] = { + k: v for k, v in headers.items() + if k.lower() not in ("host", "content-length", "transfer-encoding") + } + if _config.upstream_api_key: + forward_headers["authorization"] = f"Bearer {_config.upstream_api_key}" + elif "authorization" not in {k.lower() for k in forward_headers}: + forward_headers["authorization"] = "Bearer nvidia" + + try: + req = _http_client.build_request( + method=method, + url=upstream_url, + headers=forward_headers, + content=body, + timeout=_config.request_timeout, + ) + response = await _http_client.send(req, stream=stream) + return response + except httpx.TimeoutException: + logger.warning("upstream_timeout", path=path, timeout=_config.request_timeout) + raise + except httpx.HTTPError as exc: + logger.error("upstream_error", path=path, error=str(exc)) + raise + + +# --------------------------------------------------------------------------- +# worker 协程:消费优先级队列 + 令牌桶 + 转发 +# --------------------------------------------------------------------------- + +async def _worker_loop() -> None: + """后台 worker:持续从优先级队列取请求 → 令牌限流 → 转发 → 设置 future 结果。""" + log = logger.bind(worker="main") + log.info("worker_started") + + while True: + try: + queue_item = await _priority_queue.get(timeout=1.0) + if queue_item is None: + continue + + request_id = queue_item.request_id + payload = queue_item.payload + headers = queue_item.headers + enqueued_at = queue_item.enqueued_at + + # 查找对应的 pending future + pending_entry = _pending_requests.get(request_id) + if pending_entry is None: + log.warning("orphan_request", request_id=request_id) + continue + future, _ = pending_entry + + # 低优先级令牌等待超时处理 + if queue_item.priority == Priority.LOW: + # 放线程池执行阻塞的令牌桶调用 + got_token = await asyncio.to_thread( + _token_bucket.try_consume, + tokens=1, + timeout=_config.low_priority_timeout, + ) + if not got_token: + log.info("low_priority_timeout", request_id=request_id) + _stats["ratelimited_requests"] += 1 + if not future.done(): + future.set_exception( + _RateLimitedError( + f"低优先级请求令牌等待超时 ({_config.low_priority_timeout}s)" + ) + ) + _pending_requests.pop(request_id, None) + continue + else: + # 非低优先级:在 worker 内轮询等待令牌,避免重入队导致 future 悬挂 + # (重入队会生成新 request_id,原 future 永不 resolve → 客户端永久 hang) + got_token = await asyncio.to_thread(_token_bucket.consume, tokens=1) + if not got_token: + token_deadline = time.monotonic() + _config.request_timeout + while not got_token: + await asyncio.sleep(0.1) + got_token = await asyncio.to_thread(_token_bucket.consume, tokens=1) + if time.monotonic() > token_deadline: + break + if not got_token: + log.warning( + "token_wait_timeout", + request_id=request_id, + priority=queue_item.priority.name, + timeout=_config.request_timeout, + ) + _stats["ratelimited_requests"] += 1 + if not future.done(): + future.set_exception( + _RateLimitedError( + f"令牌等待超时 ({_config.request_timeout:.0f}s)" + ) + ) + _pending_requests.pop(request_id, None) + continue + + # 转发到上游 + upstream_start = time.monotonic() + try: + path = headers.get("x-original-path", "/v1/chat/completions") + method = headers.get("x-original-method", "POST") + # 过滤内部 headers + clean_headers = { + k: v for k, v in headers.items() + if not k.startswith("x-original-") and not k.startswith("x-request-id") + } + + resp = await _forward_to_upstream( + method=method, + path=path, + body=payload.get("_raw_body"), + headers=clean_headers, + stream=payload.get("stream", False), + ) + + upstream_latency = time.monotonic() - upstream_start + queue_latency = time.monotonic() - enqueued_at + total_latency = upstream_latency + queue_latency + + log.info( + "request_completed", + request_id=request_id, + status=resp.status_code, + upstream_latency=round(upstream_latency, 3), + queue_latency=round(queue_latency, 3), + total_latency=round(total_latency, 3), + ) + + if not future.done(): + future.set_result(resp) + + except (httpx.HTTPError, OSError) as exc: + log.error("upstream_request_failed", request_id=request_id, error=str(exc)) + _stats["upstream_errors"] += 1 + if not future.done(): + future.set_exception(exc) + + _pending_requests.pop(request_id, None) + + except asyncio.CancelledError: + log.info("worker_cancelled") + break + except Exception: + log.exception("worker_unexpected_error") + + +# --------------------------------------------------------------------------- +# PASSTHROUGH 直通路径(队列满 + PASSTHROUGH 策略) +# --------------------------------------------------------------------------- + +async def _passthrough_with_rate_limit( + request: Request, + path: str, + body_bytes: bytes, + raw_headers: dict[str, str], + priority: Priority, +) -> Response: + """队列满时的 PASSSTHROUGH 直通路径:仍受令牌桶限流,但不排队。 + + Args: + request: FastAPI Request。 + path: 请求路径。 + body_bytes: 原始请求体。 + raw_headers: 请求 headers。 + priority: 请求优先级。 + + Returns: + FastAPI Response。 + """ + # 低优先级走令牌桶等待 + if priority == Priority.LOW: + got_token = await asyncio.to_thread( + _token_bucket.try_consume, + tokens=1, + timeout=_config.low_priority_timeout, + ) + if not got_token: + _stats["ratelimited_requests"] += 1 + return JSONResponse( + status_code=429, + content={ + "error": { + "message": f"令牌不足(队列满 + passthrough),超时 {_config.low_priority_timeout}s", + "type": "RateLimitedError", + } + }, + ) + else: + got_token = await asyncio.to_thread(_token_bucket.consume, tokens=1) + if not got_token: + # 非低优先级轮询等待 + deadline = time.monotonic() + 30.0 + while not got_token: + await asyncio.sleep(0.1) + got_token = await asyncio.to_thread(_token_bucket.consume, tokens=1) + if time.monotonic() > deadline: + _stats["ratelimited_requests"] += 1 + return JSONResponse( + status_code=429, + content={ + "error": { + "message": "令牌不足(队列满 + passthrough),等待超时 30s", + "type": "RateLimitedError", + } + }, + ) + + # 拿到令牌,直接转发 + try: + clean_headers = {k: v for k, v in raw_headers.items()} + resp = await _forward_to_upstream( + method=request.method, + path=path, + body=body_bytes if body_bytes else None, + headers=clean_headers, + stream=False, + ) + return _build_response(resp) + except Exception as exc: + status, msg = _map_exception(exc) + logger.error("passthrough_error", path=path, error=str(exc)) + return JSONResponse( + status_code=status, + content={"error": {"message": msg, "type": type(exc).__name__}}, + ) + + +# --------------------------------------------------------------------------- +# 自定义异常 +# --------------------------------------------------------------------------- + +class _RateLimitedError(Exception): + """429 限流错误。""" + pass + + +# --------------------------------------------------------------------------- +# 异常处理矩阵 (§3.4) +# --------------------------------------------------------------------------- + +_EXCEPTION_MATRIX: dict[type[Exception], tuple[int, str]] = { + _RateLimitedError: (429, "Too Many Requests — 令牌不足"), + QueueFullError: (503, "Service Unavailable — 队列已满"), + httpx.TimeoutException: (504, "Gateway Timeout — 上游超时"), + httpx.ConnectError: (502, "Bad Gateway — 上游连接失败"), + httpx.HTTPStatusError: (502, "Bad Gateway — 上游返回错误状态"), +} + + +def _map_exception(exc: Exception) -> tuple[int, str]: + """将异常映射为 HTTP 状态码 + 错误信息。""" + for exc_type, (status, msg) in _EXCEPTION_MATRIX.items(): + if isinstance(exc, exc_type): + return status, msg + return 500, f"Internal Server Error — {type(exc).__name__}" + + +# --------------------------------------------------------------------------- +# FastAPI 应用 + lifespan +# --------------------------------------------------------------------------- + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, Any]: + """应用生命周期管理:初始化/清理全局资源。""" + global _config, _http_client, _priority_queue, _token_bucket, _pending_requests + + # 启动 + _config = load_config() + logging.getLogger().setLevel(_config.log_level.upper()) + + _http_client = httpx.AsyncClient( + timeout=httpx.Timeout(_config.request_timeout), + ) + _priority_queue = PriorityRequestQueue(max_size=_config.queue_max_size) + _token_bucket = TokenBucket( + rate=_config.rate_rpm / 60.0, + capacity=_config.bucket_capacity, + ) + _pending_requests = {} + _stats["start_time"] = int(time.time()) + + # 启动 worker 协程 + worker_task = asyncio.create_task(_worker_loop()) + + logger.info( + "sidecar_started", + host=_config.listen_host, + port=_config.listen_port, + rate_rpm=_config.rate_rpm, + queue_max=_config.queue_max_size, + ) + + yield # app 运行中 + + # 关闭 + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + + await _http_client.aclose() + logger.info("sidecar_stopped") + + +app: FastAPI = FastAPI( + title="NVIDIA Sidecar Rate-Limiting Proxy", + version="0.1.0", + lifespan=lifespan, +) + + +# --------------------------------------------------------------------------- +# 核心代理处理器 +# --------------------------------------------------------------------------- + +async def _handle_proxy_request(request: Request, path: str) -> Response: + """统一的代理请求处理入口。 + + 执行完整链路: + 1. 解析请求体 → 提取 model + 2. 网关识别 → 非 NVIDIA 直通 + 3. NVIDIA → 排队 + 令牌限流 + 转发 + """ + _stats["total_requests"] += 1 + + # 解析请求 + body_bytes: bytes = await request.body() + raw_headers: dict[str, str] = dict(request.headers) + + # 尝试解析 JSON body + body_json: dict[str, Any] = {} + try: + if body_bytes: + body_json = __import__("json").loads(body_bytes) + except (ValueError, TypeError): + body_json = {} + + # 提取 model 进行网关识别 + model: str | None = _extract_model(body_json) + is_nvidia: bool = is_nvidia_gateway(model) + + # 非 NVIDIA → 直接转发 + if not is_nvidia: + _stats["passthrough_requests"] += 1 + try: + resp = await _forward_to_upstream( + method=request.method, + path=path, + body=body_bytes if body_bytes else None, + headers=raw_headers, + stream=body_json.get("stream", False), + ) + return _build_response(resp) + except Exception as exc: + status, msg = _map_exception(exc) + logger.error("passthrough_error", path=path, error=str(exc)) + return JSONResponse( + status_code=status, + content={"error": {"message": msg, "type": type(exc).__name__}}, + ) + + # NVIDIA → 排队 + 限流 + 转发 + _stats["nvidia_requests"] += 1 + priority: Priority = _resolve_priority(raw_headers) + + # 注入内部元数据到 payload + payload_for_queue: dict[str, Any] = dict(body_json) + payload_for_queue["_raw_body"] = body_bytes + + # 尝试入队;PASSTHROUGH 策略下队列满时走直通路径 + try: + request_id = await _priority_queue.put( + item=payload_for_queue, + priority=priority, + headers={ + **raw_headers, + "x-original-path": path, + "x-original-method": request.method, + }, + ) + except QueueFullError: + _stats["queue_full_rejects"] += 1 + return JSONResponse( + status_code=503, + content={ + "error": { + "message": "队列已满,当前策略: reject", + "type": "QueueFullError", + } + }, + ) + except QueueFullPassthrough: + # 队列满 + PASSTHROUGH:绕过排队,尝试令牌桶后直接转发 + _stats["passthrough_requests"] += 1 + logger.info("queue_full_passthrough", path=path) + return await _passthrough_with_rate_limit(request, path, body_bytes, raw_headers, priority) + + # 创建 future 并注册到 pending + loop = asyncio.get_running_loop() + future: asyncio.Future[httpx.Response] = loop.create_future() + _pending_requests[request_id] = (future, time.monotonic()) + + try: + # 等待 worker 完成处理 + resp = await future + return _build_response(resp) + except _RateLimitedError as exc: + return JSONResponse( + status_code=429, + content={ + "error": { + "message": str(exc), + "type": "RateLimitedError", + } + }, + ) + except Exception as exc: + status, msg = _map_exception(exc) + logger.error("proxy_error", path=path, request_id=request_id, error=str(exc)) + return JSONResponse( + status_code=status, + content={"error": {"message": msg, "type": type(exc).__name__}}, + ) + + +def _build_response(resp: httpx.Response) -> Response: + """将 httpx.Response 转换为 FastAPI Response。 + + 支持 JSON 和流式 (SSE) 两种响应类型。 + """ + content_type = resp.headers.get("content-type", "") + + # 流式响应 (SSE) + if "text/event-stream" in content_type or "stream" in content_type: + return StreamingResponse( + content=resp.aiter_bytes(), + status_code=resp.status_code, + headers={ + k: v for k, v in resp.headers.items() + if k.lower() not in ("content-encoding", "transfer-encoding") + }, + media_type="text/event-stream", + ) + + # 普通 JSON 响应 + return Response( + content=resp.content, + status_code=resp.status_code, + headers={ + k: v for k, v in resp.headers.items() + if k.lower() not in ("content-encoding", "transfer-encoding") + }, + media_type=content_type or "application/json", + ) + + +# --------------------------------------------------------------------------- +# 路由 +# --------------------------------------------------------------------------- + +@app.get("/health") +async def health() -> dict[str, Any]: + """健康检查端点。""" + queue_stats = await _priority_queue.get_stats() + bucket_status = _token_bucket.get_status() + return { + "status": "ok", + "version": "0.1.0", + "uptime_seconds": int(time.time() - _stats["start_time"]) if _stats["start_time"] else 0, + "queue": queue_stats, + "token_bucket": bucket_status, + } + + +@app.get("/metrics") +async def metrics() -> dict[str, Any]: + """Prometheus 格式 metrics 端点。""" + queue_stats = await _priority_queue.get_stats() + bucket_status = _token_bucket.get_status() + return { + "requests": { + "total": _stats["total_requests"], + "nvidia": _stats["nvidia_requests"], + "passthrough": _stats["passthrough_requests"], + "ratelimited": _stats["ratelimited_requests"], + }, + "errors": { + "queue_full_rejects": _stats["queue_full_rejects"], + "upstream_errors": _stats["upstream_errors"], + }, + "queue": queue_stats, + "token_bucket": bucket_status, + "uptime_seconds": int(time.time() - _stats["start_time"]) if _stats["start_time"] else 0, + } + + +# ---- OpenAI 兼容端点 ---- + +@app.post("/v1/chat/completions") +async def chat_completions(request: Request) -> Response: + """OpenAI Chat Completions API 代理(含流式支持)。""" + return await _handle_proxy_request(request, "/v1/chat/completions") + + +@app.post("/v1/completions") +async def completions(request: Request) -> Response: + """OpenAI Completions API 代理(legacy)。""" + return await _handle_proxy_request(request, "/v1/completions") + + +@app.post("/v1/embeddings") +async def embeddings(request: Request) -> Response: + """OpenAI Embeddings API 代理。""" + return await _handle_proxy_request(request, "/v1/embeddings") + + +@app.get("/v1/models") +@app.get("/v1/models/{model_id:path}") +async def list_models(request: Request, model_id: str | None = None) -> Response: + """OpenAI Models API 代理。""" + path = f"/v1/models/{model_id}" if model_id else "/v1/models" + return await _handle_proxy_request(request, path) + + +# ---- 通用代理(catch-all 用于非标准 NVIDIA 端点) ---- + +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]) +async def catch_all(request: Request, path: str) -> Response: + """通用代理端点:转发任何未匹配的路径到上游。""" + target_path = f"/{path}" if not path.startswith("/") else path + return await _handle_proxy_request(request, target_path) + + +# --------------------------------------------------------------------------- +# 入口 +# --------------------------------------------------------------------------- + +def main() -> None: + """开发/调试入口。""" + import uvicorn + cfg: SidecarConfig = load_config() + uvicorn.run( + "nvidia_sidecar.server:app", + host=cfg.listen_host, + port=cfg.listen_port, + log_level=cfg.log_level.lower(), + ) + + +if __name__ == "__main__": + main() \ No newline at end of file