From 611ebd11a806a2c2eb279152a78f1ca585a6347b Mon Sep 17 00:00:00 2001 From: bizwings Date: Thu, 25 Jun 2026 16:39:01 +0800 Subject: [PATCH] feat(sidecar-v2): implement multi-pool provider proxy with cooldown, rate limiting, WebUI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BIZ-52 Step3 开发实现: - storage: backend/usage/cooldown/config CRUD with SQLite WAL - crypto: AES-256-GCM API key encryption - pool_manager: primary/fallback pool routing - cooldown_manager: 429 exponential backoff cooldown - rate_limiter: per-backend token bucket RPM control - router: model → backend routing with pool priority - proxy: multi-pool request forwarding with retry - server: FastAPI admin API + OpenAI-compatible proxy + SSE - dashboard: WebUI with provider CRUD, stats, charts Co-authored-by: multica-agent --- services/nvidia_sidecar/README.md | 77 +++ services/nvidia_sidecar/__init__.py | 1 + services/nvidia_sidecar/config.py | 165 +++++ services/nvidia_sidecar/cooldown_manager.py | 116 ++++ services/nvidia_sidecar/crypto.py | 108 ++++ services/nvidia_sidecar/dashboard.html | 605 ++++++++++++++++++ services/nvidia_sidecar/main.py | 17 + services/nvidia_sidecar/pool_manager.py | 83 +++ services/nvidia_sidecar/proxy.py | 321 ++++++++++ services/nvidia_sidecar/rate_limiter.py | 111 ++++ services/nvidia_sidecar/router.py | 62 ++ services/nvidia_sidecar/server.py | 477 ++++++++++++++ services/nvidia_sidecar/storage/__init__.py | 1 + .../nvidia_sidecar/storage/backend_store.py | 252 ++++++++ .../nvidia_sidecar/storage/config_store.py | 55 ++ .../nvidia_sidecar/storage/cooldown_store.py | 74 +++ services/nvidia_sidecar/storage/db.py | 193 ++++++ services/nvidia_sidecar/storage/models.py | 161 +++++ .../nvidia_sidecar/storage/usage_store.py | 155 +++++ 19 files changed, 3034 insertions(+) create mode 100644 services/nvidia_sidecar/README.md create mode 100644 services/nvidia_sidecar/__init__.py create mode 100644 services/nvidia_sidecar/config.py 
create mode 100644 services/nvidia_sidecar/cooldown_manager.py create mode 100644 services/nvidia_sidecar/crypto.py create mode 100644 services/nvidia_sidecar/dashboard.html create mode 100644 services/nvidia_sidecar/main.py create mode 100644 services/nvidia_sidecar/pool_manager.py create mode 100644 services/nvidia_sidecar/proxy.py create mode 100644 services/nvidia_sidecar/rate_limiter.py create mode 100644 services/nvidia_sidecar/router.py create mode 100644 services/nvidia_sidecar/server.py create mode 100644 services/nvidia_sidecar/storage/__init__.py create mode 100644 services/nvidia_sidecar/storage/backend_store.py create mode 100644 services/nvidia_sidecar/storage/config_store.py create mode 100644 services/nvidia_sidecar/storage/cooldown_store.py create mode 100644 services/nvidia_sidecar/storage/db.py create mode 100644 services/nvidia_sidecar/storage/models.py create mode 100644 services/nvidia_sidecar/storage/usage_store.py diff --git a/services/nvidia_sidecar/README.md b/services/nvidia_sidecar/README.md new file mode 100644 index 0000000..056e871 --- /dev/null +++ b/services/nvidia_sidecar/README.md @@ -0,0 +1,77 @@ +# Sidecar V2 — Multi-Pool Provider Proxy + +## 概述 +Sidecar V2 是 OpenClaw 的 API 代理服务,实现多 Provider 池管理、负载均衡、429 冷却、RPM 队列控流。 + +## 核心功能 +- **Provider 池管理**:主池 (primary) + 备用池 (fallback),支持动态增删 Provider +- **429 冷却**:检测 429 → 自动冷却 → 指数退避 → 自动恢复 +- **按 Provider 独立 RPM 限流**:每个 Provider 独立的 Token Bucket +- **路由策略**:主池优先 → 备用池兜底 → 全部耗尽返 503 +- **WebUI 管理**:Dashboard 仪表盘 + Provider CRUD +- **用量统计**:Token 用量 + 费用统计 + 每小时/每日聚合 +- **API Key 加密**:AES-256-GCM 加密存储 + +## 架构 + +``` +OpenClaw → Sidecar V2 (port 9190) → 路由 → 主池 Provider 1,2,3... + ↘ 备池 Provider 4,5... 
+ ↘ 全部耗尽 → 503 +``` + +## 快速开始 + +```bash +# 设置加密密钥 (64位十六进制) +export SIDECAR_ENCRYPTION_KEY="0000111122223333444455556666777788889999aaaabbbbccccddddeeeeffff" + +# 启动服务 +python3 main.py + +# OR via uvicorn +python3 -m uvicorn server:app --host 127.0.0.1 --port 9190 +``` + +## WebUI +访问 http://127.0.0.1:9190/dashboard + +## API 端点 + +### Admin API +- `GET /api/admin/backends` — 列出所有 Provider +- `POST /api/admin/backends` — 添加 Provider +- `PUT /api/admin/backends/{id}` — 更新 Provider +- `DELETE /api/admin/backends/{id}` — 删除 Provider +- `GET /api/admin/pools` — 池状态汇总 +- `GET /api/admin/stats/total` — 总计统计 +- `GET /api/admin/stats/hourly` — 每小时用量 +- `GET /api/admin/stats/daily` — 每日聚合 +- `GET /api/admin/stats/cooldown` — 冷却事件历史 +- `GET /api/admin/config` — 系统配置 + +### 代理 API (OpenAI 兼容) +- `POST /v1/chat/completions` +- `POST /v1/completions` +- `POST /v1/embeddings` +- `GET /v1/models` + +### 监控 +- `GET /health` — 健康检查 +- `GET /dashboard/sse` — Dashboard 实时数据流 (SSE) + +## 环境变量 + +| 变量 | 默认值 | 说明 | +|------|--------|------| +| SIDECAR_HOST | 127.0.0.1 | 监听地址 | +| SIDECAR_PORT | 9190 | 监听端口 | +| SIDECAR_ENCRYPTION_KEY | (必填) | API Key 加密密钥 (64 hex chars) | +| SIDECAR_DB_PATH | ./data/sidecar_v2.db | SQLite 数据库路径 | +| SIDECAR_RATE_RPM | 40 | 默认 RPM 限制 | +| SIDECAR_COOLDOWN_BASE | 30 | 冷却基础时长 (秒) | +| SIDECAR_COOLDOWN_MAX | 600 | 冷却最大时长 (秒) | + +## 存储 +- SQLite (WAL 模式) +- 表:backends, backend_usage_logs, cooldown_events, backend_health, system_config, daily_stats \ No newline at end of file diff --git a/services/nvidia_sidecar/__init__.py b/services/nvidia_sidecar/__init__.py new file mode 100644 index 0000000..935ac6f --- /dev/null +++ b/services/nvidia_sidecar/__init__.py @@ -0,0 +1 @@ +"""Sidecar V2 — Multi-pool provider proxy with cooldown, rate limiting, and WebUI management.""" \ No newline at end of file diff --git a/services/nvidia_sidecar/config.py b/services/nvidia_sidecar/config.py new file mode 100644 index 0000000..4c3d951 --- /dev/null +++ 
@dataclass
class Config:
    """Sidecar V2 runtime configuration.

    Sources (priority order):
    1. Environment variables (highest)
    2. system_config table in SQLite
    3. Defaults defined here
    """

    # Listen
    host: str = "127.0.0.1"
    port: int = 9190
    metrics_port: int = 9191

    # Queue
    queue_max_depth: int = 500
    queue_timeout_seconds: float = 30.0

    # Provider
    default_rpm_limit: int = 40

    # Cooldown
    cooldown_base_seconds: float = 30.0
    cooldown_max_seconds: float = 600.0
    cooldown_exponential_backoff: bool = True

    # Emergency channel: RPM fraction when all pools exhausted
    emergency_rpm_fraction: float = 0.10

    # Health check
    health_check_interval_seconds: int = 60
    health_check_timeout_seconds: int = 10
    health_probe_endpoint: str = "/v1/models"

    # Admin auth
    admin_token: str = ""

    # Encryption
    encryption_key: str = ""

    # Logging
    log_level: str = "INFO"

    # Database
    db_path: str = ""
    backup_dir: str = ""
    backup_retention_days: int = 7

    # Rate limiter
    rate_limiter_refill_interval_ms: int = 50

    # Router
    router_refresh_interval_seconds: float = 5.0

    # Max pool-internal retries
    max_pool_retries: int = 5

    # Pre-check cooldown threshold (seconds remaining)
    cooldown_precheck_threshold_seconds: float = 10.0

    # Dashboard
    dashboard_sse_interval_seconds: float = 1.0

    # Stats
    stats_refresh_interval_seconds: float = 30.0

    # Request timeout
    default_request_timeout_seconds: int = 120

    # Field name -> environment variable read by from_env(). Deliberately left
    # unannotated so the dataclass machinery does not treat it as a field.
    # Keep in sync with from_env(): merge_db() uses it to decide whether an
    # environment override is in effect for a given field.
    _ENV_VARS = {
        "host": "SIDECAR_HOST",
        "port": "SIDECAR_PORT",
        "metrics_port": "SIDECAR_METRICS_PORT",
        "queue_max_depth": "SIDECAR_QUEUE_MAX",
        "queue_timeout_seconds": "SIDECAR_QUEUE_TIMEOUT",
        "default_rpm_limit": "SIDECAR_RATE_RPM",
        "cooldown_base_seconds": "SIDECAR_COOLDOWN_BASE",
        "cooldown_max_seconds": "SIDECAR_COOLDOWN_MAX",
        "admin_token": "SIDECAR_ADMIN_TOKEN",
        "encryption_key": "SIDECAR_ENCRYPTION_KEY",
        "log_level": "LOG_LEVEL",
        "db_path": "SIDECAR_DB_PATH",
        "backup_dir": "SIDECAR_BACKUP_DIR",
    }

    @classmethod
    def from_env(cls) -> "Config":
        """Load configuration from environment variables.

        Unset variables fall back to the dataclass defaults, except for
        ``db_path`` / ``backup_dir`` which default to paths under the current
        working directory.

        Raises:
            ValueError: If a numeric variable holds a non-numeric string.
        """
        c = cls()

        # Listen
        c.host = os.getenv("SIDECAR_HOST", c.host)
        c.port = int(os.getenv("SIDECAR_PORT", str(c.port)))
        c.metrics_port = int(os.getenv("SIDECAR_METRICS_PORT", str(c.metrics_port)))

        # Queue
        c.queue_max_depth = int(os.getenv("SIDECAR_QUEUE_MAX", str(c.queue_max_depth)))
        c.queue_timeout_seconds = float(
            os.getenv("SIDECAR_QUEUE_TIMEOUT", str(c.queue_timeout_seconds))
        )

        # Provider
        c.default_rpm_limit = int(
            os.getenv("SIDECAR_RATE_RPM", str(c.default_rpm_limit))
        )

        # Cooldown
        c.cooldown_base_seconds = float(
            os.getenv("SIDECAR_COOLDOWN_BASE", str(c.cooldown_base_seconds))
        )
        c.cooldown_max_seconds = float(
            os.getenv("SIDECAR_COOLDOWN_MAX", str(c.cooldown_max_seconds))
        )

        # Admin
        c.admin_token = os.getenv("SIDECAR_ADMIN_TOKEN", c.admin_token)

        # Encryption
        c.encryption_key = os.getenv("SIDECAR_ENCRYPTION_KEY", c.encryption_key)

        # Logging
        c.log_level = os.getenv("LOG_LEVEL", c.log_level).upper()

        # Database
        c.db_path = os.getenv(
            "SIDECAR_DB_PATH",
            os.path.join(os.getcwd(), "data", "sidecar_v2.db"),
        )
        c.backup_dir = os.getenv(
            "SIDECAR_BACKUP_DIR",
            os.path.join(os.getcwd(), "data", "backups"),
        )

        # V1 compatibility: migrate env vars
        c._migrate_v1_env()

        return c

    def _migrate_v1_env(self) -> None:
        """Migrate V1 environment variables to V2 defaults.

        When both a V1 API key and a V2 encryption key are present, the V1
        key and upstream URL are copied into private ``_SIDECAR_V1_*``
        environment variables.
        """
        # V1 UPSTREAM endpoint
        upstream = os.getenv("SIDECAR_UPSTREAM")
        api_key = os.getenv("SIDECAR_API_KEY")
        if api_key and self.encryption_key:
            # These will be used during initial migration
            os.environ["_SIDECAR_V1_API_KEY"] = api_key
            os.environ["_SIDECAR_V1_UPSTREAM"] = (
                upstream or "https://integrate.api.nvidia.com/v1"
            )

    def to_db_dict(self) -> dict:
        """Serialize to dict for system_config storage.

        Booleans become ``"true"`` / ``"false"``, numbers become their
        ``str()`` form, and strings are stored unchanged.
        """
        result = {}
        for key, value in asdict(self).items():
            # bool must be tested before int: bool is a subclass of int.
            if isinstance(value, bool):
                result[key] = "true" if value else "false"
            elif isinstance(value, (int, float)):
                result[key] = str(value)
            else:
                result[key] = value
        return result

    @staticmethod
    def _coerce_db_value(raw, target_type):
        """Convert a stored system_config value back to *target_type*.

        Booleans need explicit parsing: ``bool("false")`` is ``True`` because
        any non-empty string is truthy.
        """
        if target_type is bool:
            if isinstance(raw, bool):
                return raw
            return str(raw).strip().lower() in ("1", "true", "yes", "on")
        return target_type(raw)

    @classmethod
    def merge_db(cls, base: "Config", db_config: dict) -> "Config":
        """Merge DB config into base config (env vars already applied to base).

        A DB value is applied only when the environment variable that feeds
        the same field (see ``_ENV_VARS``) is not set, so environment
        overrides keep the highest priority. Fields without an environment
        variable always take the DB value.

        Args:
            base: Config to update in place; normally the result of
                ``from_env()``.
            db_config: Mapping of field name to stored value, in the format
                produced by ``to_db_dict()``.

        Returns:
            The same ``base`` instance, for chaining.

        Raises:
            ValueError: If a stored value cannot be converted to the field's
                type.
        """
        # Snapshot the items so that setattr() never races the iteration.
        for key, current in list(vars(base).items()):
            if key not in db_config:
                continue
            env_name = cls._ENV_VARS.get(key)
            if env_name is not None and env_name in os.environ:
                # Environment override in effect: leave the env value alone.
                continue
            setattr(base, key, cls._coerce_db_value(db_config[key], type(current)))
        return base
def check_and_clear_cooldown(backend_id: str) -> bool:
    """Check if cooldown has expired for a backend.

    The stored ``cooldown_until`` value is parsed as an ISO-8601 timestamp.
    ``start_cooldown`` writes it in UTC with a ``Z`` suffix; a value without
    any UTC offset is also treated as UTC rather than as local time, so the
    expiry check does not shift with the host's timezone.

    Args:
        backend_id: Identifier of the backend to inspect.

    Returns:
        True if cooldown was cleared (backend is back online); False if the
        backend is unknown, is not cooling, or is still inside its cooldown
        window.
    """
    from storage.backend_store import get_backend

    backend = get_backend(backend_id, decrypt_key=False)
    if backend is None:
        return False

    if backend.status != "cooling":
        return False

    cooldown_until = backend.cooldown_until
    if not cooldown_until:
        # Cooling without a deadline can never expire by itself: release it.
        # NOTE(review): the cooldown event is not ended on this path or on
        # the parse-failure path below — confirm whether end_cooldown_event
        # should be called here too.
        clear_backend_cooldown(backend_id)
        return True

    # Parse cooldown_until as ISO timestamp
    try:
        deadline = datetime.fromisoformat(cooldown_until.replace("Z", "+00:00"))
    except ValueError:
        # If parsing fails, clear and move on
        clear_backend_cooldown(backend_id)
        return True

    if deadline.tzinfo is None:
        # datetime.timestamp() assumes local time for naive values.
        deadline = deadline.replace(tzinfo=timezone.utc)
    cooldown_ts = deadline.timestamp()

    now = time.time()
    if now >= cooldown_ts:
        clear_backend_cooldown(backend_id)
        end_cooldown_event(backend_id)
        logger.info("cooldown_cleared", backend_id=backend_id)
        return True

    remaining = cooldown_ts - now
    logger.debug(
        "cooldown_active",
        backend_id=backend_id,
        remaining_seconds=round(remaining, 1),
    )
    return False
def init_crypto(hex_key: str) -> None:
    """Initialize the encryption module.

    Validates the key and prepares the module-level AES-256-GCM cipher.

    Args:
        hex_key: The encryption key as exactly 64 hexadecimal characters
            (32 bytes), with no whitespace or separators.

    Raises:
        ValueError: If the key is empty, has the wrong length, is not valid
            hexadecimal, or does not decode to exactly 32 bytes.
    """
    global _ENCRYPTION_KEY, _cipher

    if not hex_key:
        raise ValueError("FATAL: SIDECAR_ENCRYPTION_KEY not set")

    if len(hex_key) != 64:
        raise ValueError(
            f"FATAL: SIDECAR_ENCRYPTION_KEY must be 64 hex chars (32 bytes), "
            f"got {len(hex_key)} chars"
        )

    try:
        key_bytes = bytes.fromhex(hex_key)
    except ValueError as exc:
        raise ValueError(
            "FATAL: SIDECAR_ENCRYPTION_KEY must be valid hexadecimal"
        ) from exc

    # bytes.fromhex() skips ASCII whitespace, so a 64-character string that
    # contains spaces decodes to fewer than 32 bytes. Reject it instead of
    # letting a 16- or 24-byte key silently select a weaker AES variant.
    if len(key_bytes) != 32:
        raise ValueError(
            "FATAL: SIDECAR_ENCRYPTION_KEY must be valid hexadecimal"
        )

    _ENCRYPTION_KEY = key_bytes
    _cipher = AESGCM(key_bytes)
    logger.info("crypto_initialized")
def mask_api_key(api_key_plain: str) -> str:
    """Return a display-safe form of an API key.

    Keys longer than 10 characters keep their first 6 and last 4 characters
    around a ``****`` filler. Shorter keys keep only their first 2 characters
    followed by the filler.
    """
    is_short = len(api_key_plain) <= 10
    visible_head = api_key_plain[: 2 if is_short else 6]
    visible_tail = "" if is_short else api_key_plain[-4:]
    return f"{visible_head}****{visible_tail}"
+ + + + +
+ +
+
+
+
+
+ + +
+
+

Provider Backends

+ +
+ + + + + +
NameLabelPoolStatusRPMModelsActions
+
+ + +
+

Hourly Usage

+
+ +
+ + + + + +
HourBackendModelRequestsErrorsTokensCostAvg Latency
+ +

Daily Aggregation

+ + + + + +
DatePoolRequestsErrorsTokensCostBackends
+
+ + +
+

Cooldown Event History

+ + + + + +
TimeBackendConsecutive 429sDurationSummary
+
+
+
+ + + + + + + \ No newline at end of file diff --git a/services/nvidia_sidecar/main.py b/services/nvidia_sidecar/main.py new file mode 100644 index 0000000..22df396 --- /dev/null +++ b/services/nvidia_sidecar/main.py @@ -0,0 +1,17 @@ +"""Sidecar V2 entry point.""" + +import uvicorn +from config import config + + +def main(): + uvicorn.run( + "server:app", + host=config.host, + port=config.port, + log_level=config.log_level.lower(), + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/services/nvidia_sidecar/pool_manager.py b/services/nvidia_sidecar/pool_manager.py new file mode 100644 index 0000000..29c519f --- /dev/null +++ b/services/nvidia_sidecar/pool_manager.py @@ -0,0 +1,83 @@ +"""Provider pool management: primary / fallback pool routing.""" + +import structlog +from typing import Optional + +from storage.backend_store import list_backends, get_pool_stats +from storage.models import Backend + +logger = structlog.get_logger("sidecar_v2.pool_manager") + + +class PoolManager: + """Manages provider pools and selects healthy backends for a given model. + + Priority: primary pool → fallback pool. + Within a pool: healthy backends only, sorted by availability. + """ + + def __init__(self): + self._pool_order = ["primary", "fallback"] + + def get_available_backends( + self, canonical_model: str, pool: Optional[str] = None + ) -> list[Backend]: + """Get all healthy, enabled backends that serve a model, in pool order. + + Args: + canonical_model: Canonical model name to match. + pool: Optional pool filter (primary/fallback). None = all pools. + + Returns: + List of ready backends sorted by pool priority, then RPM utilization. 
+ """ + backends: list[Backend] = [] + + pools_to_check = [pool] if pool else self._pool_order + for p in pools_to_check: + pool_backends = list_backends(pool=p, enabled_only=True, decrypt_key=True) + for b in pool_backends: + if b.status == "healthy" and b.has_model(canonical_model): + backends.append(b) + if pool: + break + + return backends + + def get_any_healthy_backends(self, pool: Optional[str] = None) -> list[Backend]: + """Get all healthy, enabled backends regardless of model.""" + backends: list[Backend] = [] + pools_to_check = [pool] if pool else self._pool_order + for p in pools_to_check: + pool_backends = list_backends(pool=p, enabled_only=True, decrypt_key=True) + for b in pool_backends: + if b.status == "healthy": + backends.append(b) + if pool: + break + return backends + + def get_pool_status(self) -> dict: + """Get pool summary for dashboard.""" + stats = get_pool_stats() + result = {} + for pool in self._pool_order: + s = stats.get(pool, {"total": 0, "enabled": 0, "healthy": 0, "cooling": 0, "error": 0}) + result[pool] = s + # Also include any other pools + for pool, s in stats.items(): + if pool not in result: + result[pool] = s + return result + + def is_pool_available(self, canonical_model: str, pool: str = "primary") -> bool: + """Check if a pool has any healthy backends for a model.""" + backends = self.get_available_backends(canonical_model, pool=pool) + return len(backends) > 0 + + def is_any_pool_available(self, canonical_model: str) -> bool: + """Check if any pool has healthy backends for a model.""" + for pool in self._pool_order: + if self.is_pool_available(canonical_model, pool): + return True + return False \ No newline at end of file diff --git a/services/nvidia_sidecar/proxy.py b/services/nvidia_sidecar/proxy.py new file mode 100644 index 0000000..b419e4f --- /dev/null +++ b/services/nvidia_sidecar/proxy.py @@ -0,0 +1,321 @@ +"""Proxy request handling for Sidecar V2 — multi-pool routing + cooldown + rate limiting.""" + +import 
async def forward_to_backend(
    backend: Backend,
    method: str,
    path: str,
    body: bytes | None,
    headers: dict[str, str],
    stream: bool = False,
) -> httpx.Response:
    """Forward a request to a specific backend.

    Args:
        backend: Target backend; supplies base URL, API key and timeout.
        method: HTTP method to use upstream.
        path: Request path appended to the backend's base URL.
        body: Raw request body, or None when there is none.
        headers: Incoming request headers. ``host``, ``content-length`` and
            ``transfer-encoding`` are dropped before forwarding.
        stream: When True the response body is NOT read here. The caller
            must either consume it completely (e.g. via ``aiter_bytes()``)
            or call ``aclose()`` on it; doing so also releases the HTTP
            client that owns the connection.

    Returns:
        The upstream response: fully buffered when ``stream`` is False, with
        an open body stream otherwise.

    Raises:
        httpx.HTTPError: On timeouts and transport failures.
    """
    upstream_url = backend.api_base_url.rstrip("/") + path

    forward_headers = {
        k: v
        for k, v in headers.items()
        if k.lower() not in ("host", "content-length", "transfer-encoding")
    }

    if backend.api_key_plain:
        # Drop every caller-supplied Authorization header, whatever its
        # case, so the caller's credential is never sent next to ours.
        forward_headers = {
            k: v for k, v in forward_headers.items() if k.lower() != "authorization"
        }
        forward_headers["authorization"] = f"Bearer {backend.api_key_plain}"
    elif "authorization" not in {k.lower() for k in forward_headers}:
        forward_headers["authorization"] = "Bearer nvidia"

    timeout = httpx.Timeout(backend.timeout_seconds)

    # The client is managed by hand instead of with ``async with``: leaving
    # the context manager closes the connection, which would cut off a
    # streamed body before the caller has read any of it.
    client = httpx.AsyncClient(timeout=timeout)
    try:
        req = client.build_request(
            method=method,
            url=upstream_url,
            headers=forward_headers,
            content=body,
        )
        resp = await client.send(req, stream=stream)
    except BaseException:
        # Includes task cancellation: never leak the client on failure.
        await client.aclose()
        raise

    if not stream:
        # The body is already buffered; the connection is no longer needed.
        await client.aclose()
        return resp

    # Streaming: keep the client alive until the response itself is closed.
    # httpx closes the response once the body iterator is exhausted, so the
    # client's lifetime is tied to that close.
    close_response = resp.aclose

    async def _close_response_and_client() -> None:
        try:
            await close_response()
        finally:
            await client.aclose()

    resp.aclose = _close_response_and_client  # type: ignore[method-assign]
    return resp
def extract_usage_from_response(
    resp: "httpx.Response",
    resp_json: dict[str, Any],
    model: str,
) -> tuple[int, int, int]:
    """Extract token usage from response body (OpenAI-compatible).

    Reads the top-level ``usage`` object first. When it reports no prompt or
    completion tokens, per-choice ``usage`` objects are summed instead.
    Missing, ``null`` or non-object ``usage`` / ``choices`` values count as
    "no usage reported" rather than raising.

    Args:
        resp: The upstream response (not inspected here).
        resp_json: Decoded JSON body; may be empty.
        model: Model name (not inspected here).

    Returns:
        ``(prompt_tokens, completion_tokens, total_tokens)``. The total is
        the sum of the first two, or the reported ``total_tokens`` when that
        sum is zero.
    """
    usage = resp_json.get("usage")
    if not isinstance(usage, dict):
        # Some providers send "usage": null; treat it like a missing key.
        usage = {}
    prompt_tokens = usage.get("prompt_tokens", 0) or 0
    completion_tokens = usage.get("completion_tokens", 0) or 0

    # Try streaming chunks: aggregate from choices
    if not prompt_tokens and not completion_tokens:
        choices = resp_json.get("choices")
        if not isinstance(choices, list):
            choices = []
        for choice in choices:
            if not isinstance(choice, dict):
                continue
            tokens = choice.get("usage")
            if not isinstance(tokens, dict):
                continue
            prompt_tokens += tokens.get("prompt_tokens", 0) or 0
            completion_tokens += tokens.get("completion_tokens", 0) or 0

    total_tokens = prompt_tokens + completion_tokens
    if total_tokens == 0:
        total_tokens = usage.get("total_tokens", 0) or 0

    return prompt_tokens, completion_tokens, total_tokens
PerBackendRateLimiter, + router: Router, + request: Request, + path: str, +) -> Response: + """Main proxy handler: multi-pool routing with cooldown and rate limiting. + + Flow: + 1. Extract model → canonical name + 2. Pick backend via Router (primary → fallback) + 3. Forward request + 4. If 429 → cooldown backend, retry with another + 5. If all pools exhausted → emergency mode + 6. Track usage + """ + start_time = time.monotonic() + + body_bytes: bytes = await request.body() + raw_headers: dict[str, str] = dict(request.headers) + + body_json: dict[str, Any] = {} + try: + if body_bytes: + parsed = json.loads(body_bytes) + if isinstance(parsed, dict): + body_json = parsed + except (ValueError, TypeError): + body_json = {} + + canonical_model = extract_model(body_json) + is_stream = body_json.get("stream", False) + + # Try with pool routing + max_retries = config.max_pool_retries + for attempt in range(max_retries): + # Check and clear expired cooldowns before picking + _refresh_cooldowns(pool_manager) + + backend = router.pick_backend(canonical_model) + if backend is None: + break # No backend available, fall through to emergency + + try: + resp = await forward_to_backend( + backend=backend, + method=request.method, + path=path, + body=body_bytes if body_bytes else None, + headers=raw_headers, + stream=is_stream, + ) + elapsed_ms = int((time.monotonic() - start_time) * 1000) + + # Handle 429 — cooldown and retry + if resp.status_code == 429: + new_count = backend.consecutive_429_count + 1 + start_cooldown(backend.id, new_count) + + resp_body = "" + try: + resp_body = resp.text[:200] + except Exception: + pass + + logger.warning( + "backend_429_cooldown", + backend_id=backend.id, + pool=backend.pool, + consecutive=new_count, + model=canonical_model, + ) + + # Track the error + record_usage( + backend_id=backend.id, + model=canonical_model, + prompt_tokens=0, + completion_tokens=0, + cost=0.0, + latency_ms=elapsed_ms, + is_error=True, + ) + + continue # Retry with 
another backend + + # Success — track usage + resp_json: dict[str, Any] = {} + try: + if not is_stream and resp.content: + resp_json = json.loads(resp.content) + except (ValueError, TypeError): + pass + + prompt_tokens, completion_tokens, total_tokens = extract_usage_from_response( + resp, resp_json, canonical_model + ) + cost = calculate_cost( + backend, canonical_model, prompt_tokens, completion_tokens + ) + + record_usage( + backend_id=backend.id, + model=canonical_model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + cost=cost, + latency_ms=elapsed_ms, + ) + + logger.info( + "request_completed", + backend_id=backend.id, + pool=backend.pool, + model=canonical_model, + status=resp.status_code, + tokens=total_tokens, + cost=round(cost, 6), + elapsed_ms=elapsed_ms, + ) + + return build_response(resp) + + except httpx.TimeoutException: + logger.warning( + "backend_timeout", + backend_id=backend.id, + model=canonical_model, + ) + continue + except (httpx.ConnectError, httpx.RemoteProtocolError) as exc: + logger.warning( + "backend_connection_error", + backend_id=backend.id, + model=canonical_model, + error=str(exc), + ) + continue + except Exception as exc: + logger.error( + "proxy_error", + backend_id=backend.id, + model=canonical_model, + error=str(exc), + ) + continue + + # All backends exhausted — emergency rate-limited passthrough + emergency_rpm = int(config.default_rpm_limit * config.emergency_rpm_fraction) + if emergency_rpm < 1: + emergency_rpm = 1 + + logger.warning( + "all_pools_exhausted_emergency", + model=canonical_model, + emergency_rpm=emergency_rpm, + ) + + # Emergency: just return a clear error telling OpenClaw to use its fallback + return build_error_response( + 503, + "All provider pools exhausted. 
def _refresh_cooldowns(pool_manager: PoolManager) -> None:
    """Check and clear expired cooldowns for every backend that is cooling.

    Args:
        pool_manager: Unused; kept so existing callers keep working.
    """
    from storage.backend_store import list_backends

    # Only cooling backends can have an expired cooldown. API keys are not
    # needed for this check, so they are not decrypted.
    for backend in list_backends(decrypt_key=False):
        if backend.status == "cooling":
            check_and_clear_cooldown(backend.id)
+ + Auto-creates the bucket if needed. + """ + self.ensure_bucket(backend_id, rpm_limit) + + with self._lock: + bucket = self._buckets.get(backend_id) + if bucket is None: + return False + + return bucket.consume(tokens) + + def get_status(self, backend_id: str) -> dict[str, Any] | None: + """Get bucket status for a backend.""" + with self._lock: + bucket = self._buckets.get(backend_id) + if bucket is None: + return None + return bucket.get_status() + + def get_all_status(self) -> dict[str, dict[str, Any]]: + """Get status of all buckets.""" + with self._lock: + return {bid: b.get_status() for bid, b in self._buckets.items()} + + +class _TokenBucket: + """Internal token bucket with refill.""" + + def __init__(self, rate: float, capacity: int): + self._rate = float(rate) + self._capacity = int(capacity) + self._tokens = float(capacity) + self._last_refill = time.monotonic() + self._lock = threading.Lock() + + def _refill(self) -> None: + now = time.monotonic() + elapsed = now - self._last_refill + if elapsed > 0 and self._rate > 0: + self._tokens = min(self._tokens + elapsed * self._rate, float(self._capacity)) + self._last_refill = now + + def consume(self, tokens: int = 1) -> bool: + if tokens <= 0: + return True + with self._lock: + self._refill() + if self._tokens >= tokens: + self._tokens -= tokens + return True + return False + + def update_rate(self, rpm_limit: int) -> None: + new_rate = rpm_limit / 60.0 + with self._lock: + self._refill() + self._rate = new_rate + self._capacity = max(rpm_limit, 1) + self._tokens = min(self._tokens, float(self._capacity)) + + def get_status(self) -> dict[str, Any]: + with self._lock: + self._refill() + rate_per_minute = self._rate * 60.0 + utilization = 0.0 if self._capacity == 0 else ( + (self._capacity - self._tokens) / self._capacity + ) + return { + "tokens": round(self._tokens, 2), + "capacity": self._capacity, + "rate_per_minute": round(rate_per_minute, 1), + "utilization": round(utilization, 4), + } \ No newline at 
# ===== services/nvidia_sidecar/router.py =====
"""Model → Backend routing logic for Sidecar V2."""

import structlog
from typing import Optional

from storage.models import Backend
from pool_manager import PoolManager
from rate_limiter import PerBackendRateLimiter

logger = structlog.get_logger("sidecar_v2.router")


class Router:
    """Routes model requests to the best available backend.

    Pick strategy:
    1. Primary pool → healthy backends supporting the model
    2. Rate-limiter check → skip if RPM exhausted
    3. Fallback pool → repeat above
    4. If all exhausted → return None (caller handles emergency)
    """

    # Pools are tried strictly in this order.
    _POOL_ORDER = ("primary", "fallback")

    def __init__(self, pool_manager: PoolManager, rate_limiter: PerBackendRateLimiter):
        self._pool_manager = pool_manager
        self._rate_limiter = rate_limiter

    def pick_backend(self, canonical_model: str) -> Optional[Backend]:
        """Pick the best available backend for a model.

        Tries primary pool first, then fallback.
        Within each pool, skips backends at RPM limit.
        Returns None if no backend available.

        Note that a successful pick consumes one rate-limiter token for the
        returned backend.
        """
        for pool in self._POOL_ORDER:
            backends = self._pool_manager.get_available_backends(
                canonical_model, pool=pool
            )
            for backend in backends:
                if self._rate_limiter.consume(backend.id, backend.rpm_limit):
                    return backend
                # This backend is at its RPM limit; try the next one.
                logger.debug(
                    "backend_rate_limited",
                    backend_id=backend.id,
                    pool=pool,
                    model=canonical_model,
                )

            # Distinguish "nothing available" from "everything throttled".
            event = "pool_rpm_exhausted" if backends else "pool_exhausted"
            logger.debug(event, pool=pool, model=canonical_model)

        return None

    def get_all_pools_exhausted_info(self, canonical_model: str) -> bool:
        """Check if ALL pools are exhausted for a model."""
        return not self._pool_manager.is_any_pool_available(canonical_model)


# ===== services/nvidia_sidecar/server.py =====
"""Sidecar V2 — FastAPI server with multi-pool routing, admin API, dashboard SSE."""

import asyncio
import os
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any

import structlog
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles

from config import config as app_config
from crypto import init_crypto
from pool_manager import PoolManager
from rate_limiter import PerBackendRateLimiter
from router import Router
from proxy import handle_proxy_request

from storage.db import init_db, create_tables, run_integrity_check
from storage.backend_store import (
    create_backend, get_backend, list_backends, update_backend,
    delete_backend, get_pool_stats,
)
from storage.usage_store import get_total_stats, get_hourly_usage, get_daily_stats, aggregate_daily_stats
from storage.cooldown_store import get_cooldown_history
from storage.config_store import get_config, set_config, list_configs, delete_config
from storage.models import Backend, ModelMapping

structlog.configure(
    processors=[
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        structlog.dev.ConsoleRenderer(),
    ],
    context_class=dict,
    logger_factory=structlog.stdlib.LoggerFactory(),
    wrapper_class=structlog.stdlib.BoundLogger,
    cache_logger_on_first_use=True,
)
logger: structlog.stdlib.BoundLogger = structlog.get_logger("sidecar_v2.server")


# ──────────────────────────────────────
# Global runtime state
# ──────────────────────────────────────
# All of these are populated by ``lifespan`` at application start-up.
pool_manager: PoolManager | None = None
rate_limiter: PerBackendRateLimiter | None = None
router: Router | None = None
start_time: float = 0.0


def get_pm() -> PoolManager:
    """Return the pool manager; raise RuntimeError before start-up."""
    # Explicit raise instead of ``assert``: asserts vanish under ``python -O``.
    if pool_manager is None:
        raise RuntimeError("PoolManager is not initialized")
    return pool_manager


def get_rl() -> PerBackendRateLimiter:
    """Return the rate limiter; raise RuntimeError before start-up."""
    if rate_limiter is None:
        raise RuntimeError("PerBackendRateLimiter is not initialized")
    return rate_limiter


def get_router() -> Router:
    """Return the router; raise RuntimeError before start-up."""
    if router is None:
        raise RuntimeError("Router is not initialized")
    return router


# ──────────────────────────────────────
# Lifespan
# ──────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, Any]:
    """Initialise crypto, database and runtime state; stop tasks on exit."""
    global pool_manager, rate_limiter, router, start_time

    # Init crypto (skipped when no encryption key is configured)
    if app_config.encryption_key:
        init_crypto(app_config.encryption_key)

    # Init DB
    init_db()
    create_tables()
    if not run_integrity_check():
        # Start-up continues on a failed check; the failure is only logged.
        logger.error("db_integrity_check_failed")

    # Init runtime components
    pool_manager = PoolManager()
    rate_limiter = PerBackendRateLimiter(
        refill_interval_ms=app_config.rate_limiter_refill_interval_ms,
    )
    router = Router(pool_manager, rate_limiter)
    start_time = time.time()

    # Start background tasks
    background_tasks = (
        asyncio.create_task(_health_check_loop()),
        asyncio.create_task(_stats_aggregation_loop()),
    )

    logger.info(
        "sidecar_v2_started",
        host=app_config.host,
        port=app_config.port,
        metrics_port=app_config.metrics_port,
    )

    try:
        yield
    finally:
        # Cancel everything first, then wait for each task to unwind.
        for task in background_tasks:
            task.cancel()
        for task in background_tasks:
            try:
                await task
            except asyncio.CancelledError:
                pass
        logger.info("sidecar_v2_stopped")


app = FastAPI(
    title="Sidecar V2 — Multi-Pool Provider Proxy",
    version="2.0.0",
    lifespan=lifespan,
)


# ──────────────────────────────────────
# Background tasks
# ──────────────────────────────────────

async def _health_check_loop() -> None:
    """Periodically check and clear expired cooldowns.

    NOTE(review): the storage calls are synchronous and run on the event
    loop, like every other storage call in this module.
    """
    from cooldown_manager import check_and_clear_cooldown

    while True:
        try:
            for b in list_backends(decrypt_key=False):
                if b.status == "cooling":
                    check_and_clear_cooldown(b.id)
        except Exception:
            logger.exception("health_check_error")
        await asyncio.sleep(app_config.health_check_interval_seconds)


async def _stats_aggregation_loop() -> None:
    """Periodically aggregate daily stats (UTC days).

    When the UTC date rolls over, the previous day is aggregated one last
    time so usage recorded between the final pass and midnight is included.
    """
    last_day: str | None = None
    while True:
        try:
            today = time.strftime("%Y-%m-%d", time.gmtime())
            if last_day is not None and last_day != today:
                aggregate_daily_stats(last_day)
            aggregate_daily_stats(today)
            last_day = today
        except Exception:
            logger.exception("stats_aggregation_error")
        await asyncio.sleep(app_config.stats_refresh_interval_seconds)


# ──────────────────────────────────────
# Health / Metrics
# ──────────────────────────────────────
@app.get("/health")
async def health() -> dict[str, Any]:
    """Liveness probe: static status, version and uptime in seconds."""
    return {
        "status": "ok",
        "version": "2.0.0",
        "uptime_seconds": int(time.time() - start_time),
    }
# ──────────────────────────────────────
# Dashboard SSE
# ──────────────────────────────────────
def _backend_snapshot(b: Any, rl_status: dict[str, Any] | None) -> dict[str, Any]:
    """Build the per-backend entry of an SSE dashboard snapshot."""
    return {
        "id": b.id,
        "name": b.name,
        "label": b.label,
        "pool": b.pool,
        "enabled": b.enabled,
        "status": b.status,
        "rpm_limit": b.rpm_limit,
        "cooldown_until": b.cooldown_until,
        "consecutive_429_count": b.consecutive_429_count,
        "model_count": len(b.model_mappings),
        "rate_limiter": rl_status,
    }


@app.get("/dashboard/sse")
async def dashboard_sse() -> StreamingResponse:
    """SSE endpoint for real-time dashboard data."""
    import json  # function-scope import keeps this block self-contained

    async def event_generator():
        while True:
            try:
                limiter = get_rl()
                snapshot = {
                    "type": "snapshot",
                    "pool": get_pm().get_pool_status(),
                    "total": get_total_stats(),
                    "backends": [
                        _backend_snapshot(b, limiter.get_status(b.id))
                        for b in list_backends(decrypt_key=False)
                    ],
                    "uptime_seconds": int(time.time() - start_time),
                    "timestamp": time.time(),
                }
                yield f"data: {json.dumps(snapshot)}\n\n"
            except Exception:
                # Keep the stream alive; the next tick retries.
                logger.exception("sse_error")

            await asyncio.sleep(app_config.dashboard_sse_interval_seconds)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


# ──────────────────────────────────────
# Admin: Backend CRUD
# ──────────────────────────────────────
# NOTE(review): none of the /api/admin routes below performs authentication
# in this file — confirm access is restricted elsewhere (proxy, network).

# Mirrors the CHECK constraint on ``backends.pool`` in the schema.
_VALID_POOLS = ("primary", "fallback")


def _parse_model_mappings(raw: Any) -> dict[str, ModelMapping]:
    """Convert a request's ``model_mappings`` object; raise HTTP 400 if malformed."""
    if not isinstance(raw, dict):
        raise HTTPException(400, "model_mappings must be an object")
    try:
        return {name: ModelMapping.from_dict(mm) for name, mm in raw.items()}
    except (AttributeError, KeyError, TypeError, ValueError) as exc:
        raise HTTPException(400, f"Invalid model_mappings: {exc}") from exc


def _require_valid_pool(pool: Any) -> None:
    """Raise HTTP 400 unless *pool* is one of the pools the schema allows."""
    if pool not in _VALID_POOLS:
        raise HTTPException(400, f"pool must be one of: {', '.join(_VALID_POOLS)}")


@app.get("/api/admin/backends")
async def admin_list_backends(pool: str | None = None) -> list[dict]:
    """List all backends with masked keys."""
    backends = list_backends(pool=pool, decrypt_key=True)
    return [b.to_dict(mask_key=True) for b in backends]


@app.get("/api/admin/backends/{backend_id}")
async def admin_get_backend(backend_id: str) -> dict:
    """Get a single backend (key masked)."""
    b = get_backend(backend_id, decrypt_key=True)
    if b is None:
        raise HTTPException(404, "Backend not found")
    return b.to_dict(mask_key=True)


@app.post("/api/admin/backends")
async def admin_create_backend(body: dict[str, Any]) -> dict:
    """Create a new backend.

    Raises HTTP 400 when a required field is missing, ``pool`` is not a
    known pool, or ``model_mappings`` is malformed.
    """
    for field in ("name", "api_base_url", "api_key"):
        if field not in body:
            raise HTTPException(400, f"Missing required field: {field}")

    pool = body.get("pool", "primary")
    _require_valid_pool(pool)
    model_mappings = _parse_model_mappings(body.get("model_mappings", {}))

    backend = Backend(
        name=body["name"],
        label=body.get("label", ""),
        api_base_url=body["api_base_url"],
        api_key_plain=body["api_key"],
        api=body.get("api", "openai-completions"),
        timeout_seconds=body.get("timeout_seconds", 120),
        rpm_limit=body.get("rpm_limit", app_config.default_rpm_limit),
        pool=pool,
        enabled=body.get("enabled", True),
        model_mappings=model_mappings,
        source=body.get("source", "webui"),
    )

    created = create_backend(backend)
    return created.to_dict(mask_key=True)


@app.put("/api/admin/backends/{backend_id}")
async def admin_update_backend(backend_id: str, body: dict[str, Any]) -> dict:
    """Update a backend; only the fields present in the body are changed."""
    updates = dict(body)

    if "pool" in updates:
        _require_valid_pool(updates["pool"])

    # Handle model_mappings
    if "model_mappings" in updates:
        updates["model_mappings"] = _parse_model_mappings(updates["model_mappings"])

    # The store expects the plaintext key under ``api_key_plain``.
    if "api_key" in updates:
        updates["api_key_plain"] = updates.pop("api_key")

    updated = update_backend(backend_id, updates)
    if updated is None:
        raise HTTPException(404, "Backend not found")
    return updated.to_dict(mask_key=True)


@app.delete("/api/admin/backends/{backend_id}")
async def admin_delete_backend(backend_id: str) -> dict:
    """Delete a backend."""
    if not delete_backend(backend_id):
        raise HTTPException(404, "Backend not found")
    return {"status": "deleted", "id": backend_id}


# ──────────────────────────────────────
# Admin: Pool Status
# ──────────────────────────────────────

@app.get("/api/admin/pools")
async def admin_pool_status() -> dict:
    """Get pool summary."""
    return get_pm().get_pool_status()


# ──────────────────────────────────────
# Admin: Usage / Stats
# ──────────────────────────────────────

@app.get("/api/admin/stats/total")
async def admin_total_stats() -> dict:
    """Get aggregate usage stats."""
    return get_total_stats()


@app.get("/api/admin/stats/hourly")
async def admin_hourly_usage(
    backend_id: str | None = None,
    hours: int = 168,
) -> list[dict]:
    """Get hourly usage data for the last *hours* hours (UTC)."""
    since = None
    if hours > 0:
        since = time.strftime(
            "%Y-%m-%dT%H:%M:%SZ",
            time.gmtime(time.time() - hours * 3600),
        )
    # NOTE(review): ``limit=hours`` caps the number of rows, not hours. With
    # several backends each hour may yield several rows — confirm against
    # get_hourly_usage that this does not truncate the window.
    return get_hourly_usage(backend_id=backend_id, since=since, limit=hours)


@app.get("/api/admin/stats/daily")
async def admin_daily_stats(days: int = 30) -> list[dict]:
    """Get daily aggregated stats."""
    return get_daily_stats(days=days)


@app.get("/api/admin/stats/cooldown")
async def admin_cooldown_history(
    backend_id: str | None = None,
    limit: int = 50,
) -> list[dict]:
    """Get cooldown event history."""
    return get_cooldown_history(backend_id=backend_id, limit=limit)


# ──────────────────────────────────────
# Admin: System Config
# ──────────────────────────────────────

@app.get("/api/admin/config")
async def admin_get_all_config() -> list[dict]:
    """List all system config entries."""
    return list_configs()


@app.get("/api/admin/config/{key}")
async def admin_get_config(key: str) -> dict:
    """Get a single config value."""
    value = get_config(key)
    if value is None:
        raise HTTPException(404, "Config not found")
    return {"key": key, "value": value}


@app.put("/api/admin/config/{key}")
async def admin_set_config(key: str, body: dict[str, Any]) -> dict:
    """Set a config value (a missing ``value`` is stored as an empty string)."""
    value = str(body.get("value", ""))
    description = str(body.get("description", ""))
    set_config(key, value, description)
    return {"key": key, "value": value}


@app.delete("/api/admin/config/{key}")
async def admin_delete_config(key: str) -> dict:
    """Delete a config entry."""
    if not delete_config(key):
        raise HTTPException(404, "Config not found")
    return {"status": "deleted", "key": key}


# ──────────────────────────────────────
# Dashboard HTML
# ──────────────────────────────────────

@app.get("/dashboard")
async def dashboard_html() -> HTMLResponse:
    """Serve the dashboard WebUI, or a 404 page when the file is missing."""
    dashboard_path = os.path.join(
        os.path.dirname(__file__), "dashboard.html"
    )
    try:
        # Explicit UTF-8: the locale default may not decode the page.
        with open(dashboard_path, "r", encoding="utf-8") as f:
            return HTMLResponse(f.read())
    except FileNotFoundError:
        return HTMLResponse("<h1>Dashboard not found</h1>", status_code=404)


# ──────────────────────────────────────
# Proxy Endpoints
# ──────────────────────────────────────

async def _forward(request: Request, path: str) -> Response:
    """Hand *request* to the proxy layer for upstream *path*."""
    return await handle_proxy_request(
        pool_manager, rate_limiter, router, request, path
    )


@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Response:
    """Proxy OpenAI-compatible chat completions."""
    return await _forward(request, "/v1/chat/completions")


@app.post("/v1/completions")
async def completions(request: Request) -> Response:
    """Proxy OpenAI-compatible text completions."""
    return await _forward(request, "/v1/completions")


@app.post("/v1/embeddings")
async def embeddings(request: Request) -> Response:
    """Proxy OpenAI-compatible embeddings."""
    return await _forward(request, "/v1/embeddings")


@app.get("/v1/models")
@app.get("/v1/models/{model_id:path}")
async def list_models(request: Request, model_id: str | None = None) -> Response:
    """Proxy the model list, or a single model when *model_id* is given."""
    path = f"/v1/models/{model_id}" if model_id else "/v1/models"
    return await _forward(request, path)


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
async def catch_all(request: Request, path: str) -> Response:
    """Proxy every route not matched above."""
    target_path = path if path.startswith("/") else f"/{path}"
    return await _forward(request, target_path)


# ──────────────────────────────────────
# Main
# ──────────────────────────────────────

def main() -> None:
    """Run the server with uvicorn using the configured host and port."""
    import uvicorn

    uvicorn.run(
        "server:app",
        host=app_config.host,
        port=app_config.port,
        log_level=app_config.log_level.lower(),
    )


if __name__ == "__main__":
    main()


# ===== services/nvidia_sidecar/storage/__init__.py =====
# Sidecar V2 storage module
# ===== services/nvidia_sidecar/storage/backend_store.py =====
"""CRUD operations for Backend (provider) management."""

import json
import time
from typing import Optional

from storage.db import get_connection, generate_id
from storage.models import Backend, ModelMapping
from crypto import encrypt, decrypt


def create_backend(backend: Backend) -> Backend:
    """Create a new backend. Encrypts API key before storage.

    Fills in ``id`` (when empty), ``created_at`` and ``updated_at`` on the
    passed object and returns that same object.
    """
    if not backend.id:
        backend.id = generate_id("bkd")

    now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    backend.created_at = now
    backend.updated_at = now

    api_key_encrypted = encrypt(backend.api_key_plain)

    with get_connection() as conn:
        conn.execute(
            """INSERT INTO backends (id, name, label, api_base_url, api_key_encrypted,
               api, timeout_seconds, rpm_limit, pool, enabled, status, model_mappings_json,
               source, cooldown_until, consecutive_429_count, metadata_json, created_at, updated_at)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (
                backend.id, backend.name, backend.label, backend.api_base_url,
                api_key_encrypted, backend.api, backend.timeout_seconds,
                backend.rpm_limit, backend.pool, 1 if backend.enabled else 0,
                backend.status, json.dumps(_mappings_to_dict(backend.model_mappings)),
                backend.source, backend.cooldown_until,
                backend.consecutive_429_count,
                json.dumps(backend.metadata), backend.created_at, backend.updated_at,
            ),
        )
        conn.commit()

    return backend


def get_backend(backend_id: str, decrypt_key: bool = True) -> Optional[Backend]:
    """Get a single backend by ID, or None when it does not exist."""
    with get_connection() as conn:
        row = conn.execute(
            "SELECT * FROM backends WHERE id = ?", (backend_id,)
        ).fetchone()

    if row is None:
        return None

    return _row_to_backend(row, decrypt_key=decrypt_key)


def list_backends(
    pool: Optional[str] = None,
    enabled_only: bool = False,
    decrypt_key: bool = False,
) -> list[Backend]:
    """List backends, optionally filtered by pool and/or enabled flag."""
    conditions: list[str] = []
    params: list[str] = []
    if pool:
        conditions.append("pool = ?")
        params.append(pool)
    if enabled_only:
        # Filter in SQL instead of materialising disabled rows first.
        conditions.append("enabled = 1")

    where = f" WHERE {' AND '.join(conditions)}" if conditions else ""
    order = "created_at" if pool else "pool, created_at"

    with get_connection() as conn:
        rows = conn.execute(
            f"SELECT * FROM backends{where} ORDER BY {order}", params
        ).fetchall()

    return [_row_to_backend(r, decrypt_key=decrypt_key) for r in rows]


# Columns that map 1:1 from an ``updates`` key to a table column.
_PLAIN_UPDATE_COLUMNS = (
    "name", "label", "api_base_url", "api", "timeout_seconds",
    "rpm_limit", "pool", "status", "source",
    "cooldown_until", "consecutive_429_count",
)


def update_backend(backend_id: str, updates: dict) -> Optional[Backend]:
    """Update backend fields. If api_key_plain is provided, re-encrypt.

    Only the columns named in *updates* are written (plus ``updated_at``).
    Writing every column back from an earlier read could overwrite a
    cooldown/status change made by another writer in between.

    Returns the refreshed backend (key not decrypted), or None when
    *backend_id* does not exist. Unknown keys in *updates* are ignored.
    """
    if get_backend(backend_id, decrypt_key=False) is None:
        return None

    # Column names come only from the fixed whitelist, never from the
    # caller, so building the SET clause from them is safe.
    assignments: dict[str, object] = {
        column: updates[column]
        for column in _PLAIN_UPDATE_COLUMNS
        if column in updates
    }
    if "enabled" in updates:
        assignments["enabled"] = 1 if updates["enabled"] else 0
    if "metadata" in updates:
        assignments["metadata_json"] = json.dumps(updates["metadata"])
    if updates.get("api_key_plain"):
        # An empty/absent key leaves the stored ciphertext untouched.
        assignments["api_key_encrypted"] = encrypt(updates["api_key_plain"])
    if "model_mappings" in updates:
        assignments["model_mappings_json"] = json.dumps(
            _mappings_to_dict(updates["model_mappings"])
        )
    assignments["updated_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    set_sql = ", ".join(f"{column} = ?" for column in assignments)
    with get_connection() as conn:
        conn.execute(
            f"UPDATE backends SET {set_sql} WHERE id = ?",
            (*assignments.values(), backend_id),
        )
        conn.commit()

    return get_backend(backend_id, decrypt_key=False)


def delete_backend(backend_id: str) -> bool:
    """Delete a backend. Returns True if deleted."""
    with get_connection() as conn:
        cursor = conn.execute("DELETE FROM backends WHERE id = ?", (backend_id,))
        conn.commit()
        return cursor.rowcount > 0


def set_backend_status(backend_id: str, status: str) -> bool:
    """Quickly set backend status (healthy/cooling/error/disabled)."""
    now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    with get_connection() as conn:
        cursor = conn.execute(
            "UPDATE backends SET status = ?, updated_at = ? WHERE id = ?",
            (status, now, backend_id),
        )
        conn.commit()
        return cursor.rowcount > 0


def set_backend_cooldown(backend_id: str, cooldown_until: str, count: int) -> bool:
    """Set cooldown state on a backend."""
    now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    with get_connection() as conn:
        cursor = conn.execute(
            """UPDATE backends SET status = 'cooling', cooldown_until = ?,
               consecutive_429_count = ?, updated_at = ? WHERE id = ?""",
            (cooldown_until, count, now, backend_id),
        )
        conn.commit()
        return cursor.rowcount > 0


def clear_backend_cooldown(backend_id: str) -> bool:
    """Clear cooldown (back to healthy)."""
    now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    with get_connection() as conn:
        cursor = conn.execute(
            """UPDATE backends SET status = 'healthy', cooldown_until = NULL,
               consecutive_429_count = 0, updated_at = ? WHERE id = ?""",
            (now, backend_id),
        )
        conn.commit()
        return cursor.rowcount > 0


def get_pool_stats() -> dict:
    """Get summary stats per pool."""
    with get_connection() as conn:
        rows = conn.execute(
            """SELECT pool, COUNT(*) as total,
               SUM(CASE WHEN enabled = 1 THEN 1 ELSE 0 END) as enabled,
               SUM(CASE WHEN status = 'healthy' THEN 1 ELSE 0 END) as healthy,
               SUM(CASE WHEN status = 'cooling' THEN 1 ELSE 0 END) as cooling,
               SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as error
               FROM backends GROUP BY pool"""
        ).fetchall()
        counters = ("total", "enabled", "healthy", "cooling", "error")
        return {
            row["pool"]: {name: row[name] for name in counters}
            for row in rows
        }


def _row_to_backend(row, decrypt_key: bool = True) -> Backend:
    """Convert a DB row to a Backend instance."""
    mappings_dict = json.loads(row["model_mappings_json"] or "{}")
    model_mappings = {
        canonical_name: ModelMapping.from_dict(mm)
        for canonical_name, mm in mappings_dict.items()
    }

    backend = Backend(
        id=row["id"],
        name=row["name"],
        label=row["label"],
        api_base_url=row["api_base_url"],
        api_key_encrypted=row["api_key_encrypted"] or "",
        api=row["api"],
        timeout_seconds=row["timeout_seconds"],
        rpm_limit=row["rpm_limit"],
        pool=row["pool"],
        enabled=bool(row["enabled"]),
        status=row["status"],
        model_mappings=model_mappings,
        source=row["source"],
        cooldown_until=row["cooldown_until"],
        consecutive_429_count=row["consecutive_429_count"],
        metadata=json.loads(row["metadata_json"] or "{}"),
        created_at=row["created_at"],
        updated_at=row["updated_at"],
    )

    if decrypt_key and backend.api_key_encrypted:
        from crypto import try_decrypt_existing
        plain = try_decrypt_existing(backend.api_key_encrypted)
        # A falsy result leaves ``api_key_plain`` at its default.
        if plain:
            backend.api_key_plain = plain

    return backend


def _mappings_to_dict(mappings: dict[str, ModelMapping]) -> dict:
    """Convert ModelMapping dict to JSON-safe dict."""
    return {k: v.to_dict() for k, v in mappings.items()}


# ===== services/nvidia_sidecar/storage/config_store.py =====
"""System configuration KV store operations."""

import time
from typing import Optional, Any

from storage.db import get_connection


def get_config(key: str) -> Optional[str]:
    """Get a single config value, or None when the key is absent."""
    with get_connection() as conn:
        row = conn.execute(
            "SELECT value FROM system_config WHERE key = ?", (key,)
        ).fetchone()
        return row["value"] if row else None


def set_config(key: str, value: str, description: str = "") -> None:
    """Set or update a config value (upsert on ``key``)."""
    now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    with get_connection() as conn:
        conn.execute(
            """INSERT INTO system_config (key, value, description, updated_at)
               VALUES (?, ?, ?, ?)
               ON CONFLICT(key) DO UPDATE SET
                   value = excluded.value,
                   description = excluded.description,
                   updated_at = excluded.updated_at""",
            (key, value, description, now),
        )
        conn.commit()


def delete_config(key: str) -> bool:
    """Delete a config value. Returns True if a row was removed."""
    with get_connection() as conn:
        cursor = conn.execute(
            "DELETE FROM system_config WHERE key = ?", (key,)
        )
        conn.commit()
        return cursor.rowcount > 0


def list_configs() -> list[dict]:
    """List all system config entries."""
    with get_connection() as conn:
        rows = conn.execute("SELECT * FROM system_config ORDER BY key").fetchall()
        return [dict(row) for row in rows]


def get_all_configs_as_dict() -> dict[str, str]:
    """Get all configs as a simple dict."""
    with get_connection() as conn:
        rows = conn.execute("SELECT key, value FROM system_config").fetchall()
        return {row["key"]: row["value"] for row in rows}


# ===== services/nvidia_sidecar/storage/cooldown_store.py =====
"""Cooldown event logging."""

import time
from typing import Optional

from storage.db import get_connection, generate_id
from storage.models import CooldownEvent


def log_cooldown_event(
    backend_id: str,
    consecutive_count: int,
    cooldown_seconds: int,
    response_summary: str = "",
) -> CooldownEvent:
    """Record a cooldown event and return it."""
    event = CooldownEvent(
        id=generate_id("cev"),
        backend_id=backend_id,
        consecutive_count=consecutive_count,
        cooldown_seconds=cooldown_seconds,
        response_summary=response_summary,
        started_at=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    )

    with get_connection() as conn:
        conn.execute(
            """INSERT INTO cooldown_events
               (id, backend_id, consecutive_count, cooldown_seconds,
                response_summary, started_at)
               VALUES (?, ?, ?, ?, ?, ?)""",
            (event.id, event.backend_id, event.consecutive_count,
             event.cooldown_seconds, event.response_summary, event.started_at),
        )
        conn.commit()

    return event


def end_cooldown_event(backend_id: str) -> bool:
    """Mark the latest open cooldown event as ended.

    Returns True when an open event existed and was closed.
    """
    ended_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    with get_connection() as conn:
        # ``UPDATE ... ORDER BY ... LIMIT`` only parses on SQLite builds with
        # SQLITE_ENABLE_UPDATE_DELETE_LIMIT, so pick the row via a subquery.
        cursor = conn.execute(
            """UPDATE cooldown_events SET ended_at = ?
               WHERE id = (
                   SELECT id FROM cooldown_events
                   WHERE backend_id = ? AND ended_at IS NULL
                   ORDER BY started_at DESC LIMIT 1
               )""",
            (ended_at, backend_id),
        )
        conn.commit()
        return cursor.rowcount > 0


def get_cooldown_history(
    backend_id: Optional[str] = None,
    limit: int = 50,
) -> list[dict]:
    """Get cooldown event history, newest first."""
    with get_connection() as conn:
        if backend_id:
            rows = conn.execute(
                """SELECT * FROM cooldown_events
                   WHERE backend_id = ?
                   ORDER BY started_at DESC LIMIT ?""",
                (backend_id, limit),
            ).fetchall()
        else:
            rows = conn.execute(
                """SELECT * FROM cooldown_events
                   ORDER BY started_at DESC LIMIT ?""",
                (limit,),
            ).fetchall()
        return [dict(row) for row in rows]


# ===== services/nvidia_sidecar/storage/db.py =====
"""SQLite database connection management with WAL mode."""

import os
import sqlite3
import uuid
import structlog
from contextlib import contextmanager
from typing import Generator

from config import config

logger = structlog.get_logger()

# Module-level DB path
_DB_PATH: str = ""


def init_db(db_path: str = "") -> None:
    """Initialize the database connection and ensure WAL mode.

    Creates the data directory if needed. Uses ``config.db_path`` when
    *db_path* is empty.
    """
    global _DB_PATH
    _DB_PATH = db_path or config.db_path

    # A bare filename has no directory part; ``os.makedirs("")`` would raise.
    data_dir = os.path.dirname(_DB_PATH)
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)

    # Test connection and enable WAL
    conn = _get_raw_connection()
    try:
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA wal_autocheckpoint=1000")
        conn.execute("PRAGMA foreign_keys=ON")
        conn.execute("PRAGMA busy_timeout=5000")
        logger.info("db_initialized", path=_DB_PATH, mode="WAL")
    finally:
        conn.close()


def _get_raw_connection() -> sqlite3.Connection:
    """Get a raw sqlite3 connection with Row access, WAL and foreign keys on."""
    conn = sqlite3.connect(_DB_PATH, check_same_thread=False)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA journal_mode=WAL")
    # foreign_keys is a per-connection setting, so it is enabled here.
    conn.execute("PRAGMA foreign_keys=ON")
    return conn


@contextmanager
def get_connection() -> Generator[sqlite3.Connection, None, None]:
    """Get a database connection with WAL enabled; closed on exit.

    The caller is responsible for ``commit()``.
    """
    conn = _get_raw_connection()
    try:
        yield conn
    finally:
        conn.close()


def generate_id(prefix: str = "") -> str:
    """Generate a unique ID with optional prefix."""
    uid = uuid.uuid4().hex[:12]
    return f"{prefix}_{uid}" if prefix else uid


def create_tables() -> None:
    """Create all tables if they don't exist."""
    with get_connection() as conn:
        conn.executescript(_DDL)
        conn.commit()
    logger.info("tables_created")


def run_integrity_check() -> bool:
    """Run PRAGMA integrity_check and return True if OK."""
    with get_connection() as conn:
        result = conn.execute("PRAGMA integrity_check").fetchone()
        ok = result[0] == "ok"
        if not ok:
            logger.error("integrity_check_failed", result=result[0])
        return ok


def get_db_sizes() -> dict:
    """Get database and WAL file sizes in bytes (0 when a file is absent)."""
    result = {"db_bytes": 0, "wal_bytes": 0}
    db_path = _DB_PATH
    if os.path.exists(db_path):
        result["db_bytes"] = os.path.getsize(db_path)
    wal_path = db_path + "-wal"
    if os.path.exists(wal_path):
        result["wal_bytes"] = os.path.getsize(wal_path)
    return result

# Checkpoint modes accepted by SQLite's ``PRAGMA wal_checkpoint``.
_WAL_CHECKPOINT_MODES = frozenset({"PASSIVE", "FULL", "RESTART", "TRUNCATE"})


def wal_checkpoint(mode: str = "TRUNCATE") -> None:
    """Execute WAL checkpoint.

    Args:
        mode: One of PASSIVE, FULL, RESTART, TRUNCATE (case-insensitive).

    Raises:
        ValueError: If *mode* is not a recognised checkpoint mode. PRAGMA
            arguments cannot be bound as parameters, so the value is checked
            against a whitelist before it is interpolated.
    """
    normalized = str(mode).strip().upper()
    if normalized not in _WAL_CHECKPOINT_MODES:
        raise ValueError(f"Invalid WAL checkpoint mode: {mode!r}")
    with get_connection() as conn:
        conn.execute(f"PRAGMA wal_checkpoint({normalized})")


_DDL = """
-- Backend configuration table (core)
CREATE TABLE IF NOT EXISTS backends (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    label TEXT DEFAULT '',
    api_base_url TEXT NOT NULL,
    api_key_encrypted TEXT NOT NULL,
    api TEXT NOT NULL DEFAULT 'openai-completions',
    timeout_seconds INTEGER NOT NULL DEFAULT 120,
    rpm_limit INTEGER NOT NULL DEFAULT 40,
    pool TEXT NOT NULL DEFAULT 'primary'
        CHECK(pool IN ('primary', 'fallback')),
    enabled INTEGER NOT NULL DEFAULT 1,
    status TEXT NOT NULL DEFAULT 'healthy'
        CHECK(status IN ('healthy', 'cooling', 'error', 'disabled')),
    model_mappings_json TEXT DEFAULT '{}',
    source TEXT NOT NULL DEFAULT 'webui'
        CHECK(source IN ('webui', 'env', 'import')),
    cooldown_until TEXT,
    consecutive_429_count INTEGER DEFAULT 0,
    metadata_json TEXT DEFAULT '{}',
    created_at TEXT NOT NULL DEFAULT (datetime('now')),
    updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);

-- Usage logs (hour-bucketed, UPSERT-safe)
CREATE TABLE IF NOT EXISTS backend_usage_logs (
    id TEXT PRIMARY KEY,
    backend_id TEXT NOT NULL REFERENCES backends(id) ON DELETE CASCADE,
    model TEXT DEFAULT 'unknown',
    prompt_tokens INTEGER DEFAULT 0,
    completion_tokens INTEGER DEFAULT 0,
    total_tokens INTEGER DEFAULT 0,
    cost REAL DEFAULT 0.0,
    request_count INTEGER DEFAULT 0,
    error_count INTEGER DEFAULT 0,
    avg_latency_ms INTEGER DEFAULT 0,
    ttft_ms INTEGER DEFAULT 0,
    hour_bucket TEXT NOT NULL,
    created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_backend_hour
    ON backend_usage_logs(backend_id, hour_bucket);

-- Cooldown event log
CREATE TABLE IF NOT EXISTS cooldown_events (
    id TEXT PRIMARY KEY,
    backend_id TEXT NOT NULL REFERENCES backends(id) ON DELETE CASCADE,
    consecutive_count INTEGER NOT NULL DEFAULT 1,
    cooldown_seconds INTEGER NOT NULL,
    response_summary TEXT DEFAULT '',
    started_at TEXT NOT NULL DEFAULT (datetime('now')),
    ended_at TEXT
);
CREATE INDEX IF NOT EXISTS idx_cooldown_backend_time
    ON cooldown_events(backend_id, started_at);

-- Backend health state
CREATE TABLE IF NOT EXISTS backend_health (
    backend_id TEXT PRIMARY KEY REFERENCES backends(id) ON DELETE CASCADE,
    state TEXT NOT NULL DEFAULT 'healthy'
        CHECK(state IN ('healthy', 'degraded', 'down')),
    last_latency_ms INTEGER DEFAULT 0,
    last_status_code INTEGER DEFAULT 200,
    success_rate_5m REAL DEFAULT 1.0,
    consecutive_failures INTEGER DEFAULT 0,
    last_check_at TEXT NOT NULL DEFAULT (datetime('now'))
);

-- System configuration KV store
CREATE TABLE IF NOT EXISTS system_config (
    key TEXT PRIMARY KEY,
    value TEXT NOT NULL,
    description TEXT DEFAULT '',
    updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);

-- Daily aggregated stats
CREATE TABLE IF NOT EXISTS daily_stats (
    id TEXT PRIMARY KEY,
    date TEXT NOT NULL,
    pool TEXT NOT NULL CHECK(pool IN ('primary', 'fallback')),
    total_requests INTEGER DEFAULT 0,
    total_errors INTEGER DEFAULT 0,
    total_tokens INTEGER DEFAULT 0,
    total_cost REAL DEFAULT 0.0,
    unique_backends INTEGER DEFAULT 0,
    created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_daily_date_pool ON daily_stats(date, pool);
"""


# ===== services/nvidia_sidecar/storage/models.py =====
"""Data models for Sidecar V2 — backend-centric, Canonical Name routing."""

from dataclasses import dataclass, field, asdict
from typing import Optional
import json


@dataclass
class ModelMapping:
    """A single model mapping within a backend: Canonical Name → native_id + properties."""

    native_id: str
    reasoning: bool = False
+ reasoning_effort: bool = False + input_modalities: list[str] = field(default_factory=lambda: ["text"]) + cost: dict = field(default_factory=lambda: { + "input": 0.0, "output": 0.0, "cacheRead": 0.0, "cacheWrite": 0.0 + }) + context_window: int = 128000 + max_tokens: int = 65536 + compat: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + return asdict(self) + + @classmethod + def from_dict(cls, d: dict) -> "ModelMapping": + defaults = { + "native_id": "", + "reasoning": False, + "reasoning_effort": False, + "input_modalities": ["text"], + "cost": {"input": 0.0, "output": 0.0, "cacheRead": 0.0, "cacheWrite": 0.0}, + "context_window": 128000, + "max_tokens": 65536, + "compat": {}, + } + defaults.update(d) + return cls(**{k: v for k, v in defaults.items() if k in cls.__dataclass_fields__}) + + +@dataclass +class Backend: + """A physical API backend (API Key + URL). + + Represents a single API key endpoint. Multiple backends can serve the same + Canonical Models through their model_mappings. 
+ """ + + id: str = "" + name: str = "" + label: str = "" # e.g., "nvidia", "siliconflow" — WebUI tag only + api_base_url: str = "" + api_key_encrypted: str = "" + api: str = "openai-completions" + timeout_seconds: int = 120 + rpm_limit: int = 40 + pool: str = "primary" # primary | fallback + enabled: bool = True + status: str = "healthy" # healthy | cooling | error | disabled + model_mappings: dict[str, ModelMapping] = field(default_factory=dict) + source: str = "webui" # webui | env | import + cooldown_until: Optional[str] = None + consecutive_429_count: int = 0 + metadata: dict = field(default_factory=dict) + created_at: str = "" + updated_at: str = "" + + # Runtime fields (not persisted) + api_key_plain: str = "" # decrypted at load time, not serialized to DB + + def has_model(self, canonical_name: str) -> bool: + """Check if backend supports a given Canonical Model.""" + return canonical_name in self.model_mappings + + def get_native_id(self, canonical_name: str) -> str: + """Get this backend's native model ID for a Canonical Name.""" + mm = self.model_mappings.get(canonical_name) + return mm.native_id if mm else canonical_name + + def get_model_cost(self, canonical_name: str) -> dict: + """Get cost info for a Canonical Model on this backend.""" + mm = self.model_mappings.get(canonical_name) + return mm.cost if mm else {"input": 0.0, "output": 0.0, "cacheRead": 0.0, "cacheWrite": 0.0} + + def to_dict(self, mask_key: bool = True) -> dict: + """Convert to dict for API responses.""" + d = asdict(self) + # Remove runtime-only fields + d.pop("api_key_plain", None) + d.pop("api_key_encrypted", None) + + # Mask API key + if mask_key and self.api_key_plain: + d["api_key"] = _mask_key(self.api_key_plain) + elif self.api_key_plain: + d["api_key"] = self.api_key_plain + else: + d["api_key"] = "" + + # Convert model_mappings to dict for serialization + d["model_mappings"] = { + k: v.to_dict() for k, v in self.model_mappings.items() + } + return d + + +def _mask_key(key: 
str) -> str: + if len(key) <= 10: + return key[:2] + "****" + return key[:6] + "****" + key[-4:] + + +@dataclass +class CooldownEvent: + id: str = "" + backend_id: str = "" + consecutive_count: int = 1 + cooldown_seconds: int = 60 + response_summary: str = "" + started_at: str = "" + ended_at: Optional[str] = None + + +@dataclass +class BackendHealth: + backend_id: str = "" + state: str = "healthy" # healthy | degraded | down + last_latency_ms: int = 0 + last_status_code: int = 200 + success_rate_5m: float = 1.0 + consecutive_failures: int = 0 + last_check_at: str = "" + + +@dataclass +class UsageLog: + id: str = "" + backend_id: str = "" + model: str = "unknown" + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + cost: float = 0.0 + request_count: int = 0 + error_count: int = 0 + avg_latency_ms: int = 0 + ttft_ms: int = 0 + hour_bucket: str = "" + + +@dataclass +class DailyStats: + id: str = "" + date: str = "" + pool: str = "primary" + total_requests: int = 0 + total_errors: int = 0 + total_tokens: int = 0 + total_cost: float = 0.0 + unique_backends: int = 0 \ No newline at end of file diff --git a/services/nvidia_sidecar/storage/usage_store.py b/services/nvidia_sidecar/storage/usage_store.py new file mode 100644 index 0000000..a8cb6ab --- /dev/null +++ b/services/nvidia_sidecar/storage/usage_store.py @@ -0,0 +1,155 @@ +"""Usage logging and daily statistics aggregation.""" + +import time +from typing import Optional + +from storage.db import get_connection, generate_id + + +def record_usage( + backend_id: str, + model: str, + prompt_tokens: int, + completion_tokens: int, + cost: float, + latency_ms: int, + ttft_ms: int = 0, + is_error: bool = False, +) -> None: + """Record a single request's usage, hour-bucketed with UPSERT.""" + hour_bucket = time.strftime("%Y-%m-%dT%H:00:00Z", time.gmtime()) + uid = generate_id("use") + + with get_connection() as conn: + # Try update existing hour bucket + cursor = conn.execute( + """UPDATE 
backend_usage_logs SET + prompt_tokens = prompt_tokens + ?, + completion_tokens = completion_tokens + ?, + total_tokens = total_tokens + ?, + cost = cost + ?, + request_count = request_count + 1, + error_count = error_count + ?, + avg_latency_ms = CAST((avg_latency_ms * request_count + ?) / (request_count + 1) AS INTEGER), + ttft_ms = CASE WHEN ? > 0 THEN CAST((ttft_ms * request_count + ?) / (request_count + 1) AS INTEGER) ELSE ttft_ms END + WHERE backend_id = ? AND hour_bucket = ?""", + ( + prompt_tokens, completion_tokens, + prompt_tokens + completion_tokens, + cost, + 1 if is_error else 0, + latency_ms, + ttft_ms, ttft_ms, + backend_id, hour_bucket, + ), + ) + if cursor.rowcount == 0: + # Insert new hour bucket + conn.execute( + """INSERT INTO backend_usage_logs + (id, backend_id, model, prompt_tokens, completion_tokens, + total_tokens, cost, request_count, error_count, + avg_latency_ms, ttft_ms, hour_bucket) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + uid, backend_id, model, + prompt_tokens, completion_tokens, + prompt_tokens + completion_tokens, + cost, 1, 1 if is_error else 0, + latency_ms, ttft_ms, hour_bucket, + ), + ) + conn.commit() + + +def get_hourly_usage( + backend_id: Optional[str] = None, + since: Optional[str] = None, + limit: int = 168, +) -> list[dict]: + """Get hourly usage data, optionally filtered by backend and time range.""" + with get_connection() as conn: + if backend_id and since: + rows = conn.execute( + """SELECT * FROM backend_usage_logs + WHERE backend_id = ? AND hour_bucket >= ? + ORDER BY hour_bucket DESC LIMIT ?""", + (backend_id, since, limit), + ).fetchall() + elif backend_id: + rows = conn.execute( + """SELECT * FROM backend_usage_logs + WHERE backend_id = ? ORDER BY hour_bucket DESC LIMIT ?""", + (backend_id, limit), + ).fetchall() + elif since: + rows = conn.execute( + """SELECT * FROM backend_usage_logs + WHERE hour_bucket >= ? 
ORDER BY hour_bucket DESC LIMIT ?""", + (since, limit), + ).fetchall() + else: + rows = conn.execute( + """SELECT * FROM backend_usage_logs + ORDER BY hour_bucket DESC LIMIT ?""", + (limit,), + ).fetchall() + return [dict(row) for row in rows] + + +def get_total_stats() -> dict: + """Get aggregate stats across all backends.""" + with get_connection() as conn: + row = conn.execute( + """SELECT + SUM(request_count) as total_requests, + SUM(error_count) as total_errors, + SUM(total_tokens) as total_tokens, + SUM(prompt_tokens) as total_prompt_tokens, + SUM(completion_tokens) as total_completion_tokens, + SUM(cost) as total_cost + FROM backend_usage_logs""" + ).fetchone() + if row is None: + return { + "total_requests": 0, "total_errors": 0, + "total_tokens": 0, "total_prompt_tokens": 0, + "total_completion_tokens": 0, "total_cost": 0.0, + } + return dict(row) + + +def aggregate_daily_stats(date: str) -> None: + """Aggregate hourly usage into daily stats table.""" + with get_connection() as conn: + # Aggregate per pool + conn.execute("""DELETE FROM daily_stats WHERE date = ?""", (date,)) + conn.execute( + """INSERT INTO daily_stats (id, date, pool, total_requests, + total_errors, total_tokens, total_cost, unique_backends) + SELECT + ? || '-' || b.pool, + ?, + b.pool, + SUM(u.request_count), + SUM(u.error_count), + SUM(u.total_tokens), + SUM(u.cost), + COUNT(DISTINCT u.backend_id) + FROM backend_usage_logs u + JOIN backends b ON u.backend_id = b.id + WHERE u.hour_bucket LIKE ? + GROUP BY b.pool""", + (generate_id("day"), date, date + "%"), + ) + conn.commit() + + +def get_daily_stats(days: int = 30) -> list[dict]: + """Get daily aggregated stats.""" + with get_connection() as conn: + rows = conn.execute( + """SELECT * FROM daily_stats ORDER BY date DESC LIMIT ?""", + (days,), + ).fetchall() + return [dict(row) for row in rows] \ No newline at end of file