Files
EnterpriseArchitect/services/nvidia_sidecar/proxy.py
T
vincent 4f415fb500 fix(sidecar-v2): incorporate review feedback - P0/P1 fixes
P0 fixes:
- Admin API Bearer Token auth middleware
- Encryption key missing -> CRITICAL log + sys.exit(1)
- Prometheus metrics endpoint (:9191)
- requirements.txt + Dockerfile + docker-compose.yml + systemd + nginx

P1 fixes:
- Dead code removed from _refresh_cooldowns()
- Stream detection fixed (text/event-stream only)
- Emergency passthrough (10% RPM retry before 503)
- Active health probing for backends
- SQLite daily backup loop with retention
- Chart.js CDN fallback
- Key rotation SOP document
- JSON log format support
- Deploy files: systemd unit + nginx config

BIZ-52 review re-entry

Co-authored-by: multica-agent <github@multica.ai>
2026-06-25 17:12:33 +08:00

372 lines
12 KiB
Python

"""Proxy request handling for Sidecar V2 — multi-pool routing + cooldown + rate limiting."""
import asyncio
import json
import time
from typing import Any, Optional
import httpx
import structlog
from fastapi import Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from config import config
from pool_manager import PoolManager
from rate_limiter import PerBackendRateLimiter
from router import Router
from cooldown_manager import start_cooldown, check_and_clear_cooldown
from storage.models import Backend
from storage.usage_store import record_usage
# Module-level structured logger; all proxy events in this module are emitted through it.
logger: structlog.stdlib.BoundLogger = structlog.get_logger("sidecar_v2.proxy")
def extract_model(body: dict[str, Any]) -> str:
    """Extract the model identifier from a parsed request body.

    Returns ``"unknown"`` when the ``model`` key is absent, JSON ``null`` or an
    empty string; any other value is converted with ``str()``.
    """
    model = body.get("model")
    # A JSON null must not become the literal model name "None".
    if model is None or model == "":
        return "unknown"
    return str(model)
def build_error_response(status: int, message: str, error_type: str = "") -> JSONResponse:
    """Return a JSON error envelope ``{"error": {"message": ..., "type": ...}}``.

    When *error_type* is empty, the type falls back to ``"Error_<status>"``.
    """
    resolved_type = error_type if error_type else f"Error_{status}"
    payload = {"error": {"message": message, "type": resolved_type}}
    return JSONResponse(status_code=status, content=payload)
class _ClientClosingStream(httpx.AsyncByteStream):
    """Response body stream that also closes its owning ``AsyncClient`` on close.

    Used for ``stream=True`` requests so the per-request client outlives
    ``forward_to_backend`` and is released once the body has been consumed or
    the response is closed.
    """

    def __init__(self, inner: httpx.AsyncByteStream, client: httpx.AsyncClient) -> None:
        self._inner = inner
        self._client = client

    async def __aiter__(self):
        async for chunk in self._inner:
            yield chunk

    async def aclose(self) -> None:
        try:
            await self._inner.aclose()
        finally:
            await self._client.aclose()


async def forward_to_backend(
    backend: Backend,
    method: str,
    path: str,
    body: bytes | None,
    headers: dict[str, str],
    stream: bool = False,
) -> httpx.Response:
    """Forward a request to *backend* and return the upstream response.

    *path* is appended to ``backend.api_base_url`` (trailing ``/`` stripped).
    ``host``, ``content-length`` and ``transfer-encoding`` are not forwarded.
    When the backend has its own key it replaces any client ``Authorization``
    header; otherwise the client's header is passed through, or
    ``Bearer nvidia`` is sent if there is none.

    With ``stream=False`` the body is fully read and the HTTP client is closed
    before returning. With ``stream=True`` the body is left unread and the
    client stays open until the response is consumed or ``aclose()``d, so the
    caller must do one of those.

    Raises whatever ``httpx`` raises for the request (timeouts, connection
    errors, protocol errors, ...).
    """
    upstream_url = backend.api_base_url.rstrip("/") + path
    forward_headers = {
        k: v
        for k, v in headers.items()
        if k.lower() not in ("host", "content-length", "transfer-encoding")
    }
    has_client_auth = any(k.lower() == "authorization" for k in forward_headers)
    if backend.api_key_plain:
        # Drop every case-variant of the client's header so only the backend key is sent.
        forward_headers = {
            k: v for k, v in forward_headers.items() if k.lower() != "authorization"
        }
        forward_headers["authorization"] = f"Bearer {backend.api_key_plain}"
    elif not has_client_auth:
        forward_headers["authorization"] = "Bearer nvidia"

    client = httpx.AsyncClient(timeout=httpx.Timeout(backend.timeout_seconds))
    keep_client_open = False
    try:
        upstream_request = client.build_request(
            method=method,
            url=upstream_url,
            headers=forward_headers,
            content=body,
        )
        resp = await client.send(upstream_request, stream=stream)
        if stream:
            # The body has not been read yet: closing the client now would
            # tear down the connection mid-stream. Hand ownership of the
            # client to the response stream instead.
            resp.stream = _ClientClosingStream(resp.stream, client)
            keep_client_open = True
        return resp
    finally:
        if not keep_client_open:
            await client.aclose()
# Upstream response headers that must not be copied verbatim: httpx hands us a
# decoded body (so encoding/length no longer match), and connection-management
# headers are hop-by-hop.
_NON_FORWARDED_RESPONSE_HEADERS = frozenset(
    {"content-encoding", "content-length", "transfer-encoding", "connection", "keep-alive"}
)


async def _relay_body(resp: httpx.Response):
    """Yield *resp*'s decoded body and always close the upstream response afterwards."""
    try:
        async for chunk in resp.aiter_bytes():
            yield chunk
    finally:
        await resp.aclose()


def build_response(resp: httpx.Response) -> Response:
    """Convert an upstream ``httpx.Response`` into a FastAPI response.

    A ``StreamingResponse`` is returned when the upstream body has not been
    read yet (the request was sent with ``stream=True``), when the body is
    Server-Sent Events, or when it is chunked and not
    ``application/octet-stream``. Otherwise the buffered body is returned as a
    plain ``Response``. Encoding, length and hop-by-hop headers are dropped so
    the server recomputes framing for the body actually sent.
    """
    content_type = resp.headers.get("content-type", "")
    headers = {
        k: v
        for k, v in resp.headers.items()
        if k.lower() not in _NON_FORWARDED_RESPONSE_HEADERS
    }
    is_sse = "text/event-stream" in content_type
    is_chunked = resp.headers.get("transfer-encoding", "").lower() == "chunked"
    # A still-open response was sent with stream=True and its body is unread;
    # touching resp.content would raise httpx.ResponseNotRead.
    body_unread = not resp.is_closed
    if body_unread or is_sse or (
        is_chunked and headers.get("content-type", "") != "application/octet-stream"
    ):
        return StreamingResponse(
            content=_relay_body(resp),
            status_code=resp.status_code,
            headers=headers,
            media_type=content_type or "text/event-stream",
        )
    return Response(
        content=resp.content,
        status_code=resp.status_code,
        headers=headers,
        media_type=content_type or "application/json",
    )
def extract_usage_from_response(
    resp: "httpx.Response",
    resp_json: dict[str, Any],
    model: str,
) -> tuple[int, int, int]:
    """Extract ``(prompt_tokens, completion_tokens, total_tokens)`` from a response body.

    Reads the OpenAI-compatible top-level ``usage`` object. If it reports no
    prompt or completion tokens, the ``usage`` objects of individual
    ``choices`` are summed instead. ``total_tokens`` is prompt + completion,
    falling back to ``usage.total_tokens`` when that sum is zero.

    Missing, ``null`` or non-object ``usage`` / ``choices`` entries count as
    zero instead of raising. *resp* and *model* are currently unused and kept
    for interface compatibility.
    """
    if not isinstance(resp_json, dict):
        return 0, 0, 0
    raw_usage = resp_json.get("usage")
    # Providers may send "usage": null; treat it like an absent object.
    usage = raw_usage if isinstance(raw_usage, dict) else {}
    prompt_tokens = usage.get("prompt_tokens", 0) or 0
    completion_tokens = usage.get("completion_tokens", 0) or 0
    # Try streaming chunks: aggregate from choices
    if not prompt_tokens and not completion_tokens:
        choices = resp_json.get("choices")
        if isinstance(choices, list):
            for choice in choices:
                if not isinstance(choice, dict):
                    continue
                choice_usage = choice.get("usage")
                if not isinstance(choice_usage, dict):
                    continue
                prompt_tokens += choice_usage.get("prompt_tokens", 0) or 0
                completion_tokens += choice_usage.get("completion_tokens", 0) or 0
    total_tokens = prompt_tokens + completion_tokens
    if total_tokens == 0:
        total_tokens = usage.get("total_tokens", 0) or 0
    return prompt_tokens, completion_tokens, total_tokens
def calculate_cost(
    backend: Backend,
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> float:
    """Price a request from its token counts.

    ``backend.get_model_cost(model)`` supplies per-token prices under the
    ``"input"`` and ``"output"`` keys; a missing key is priced at 0.0.
    """
    pricing = backend.get_model_cost(model)
    per_input_token = pricing.get("input", 0.0)
    per_output_token = pricing.get("output", 0.0)
    input_total = prompt_tokens * per_input_token
    output_total = completion_tokens * per_output_token
    return input_total + output_total
def _parse_json_object(raw: bytes | None) -> dict[str, Any]:
    """Decode *raw* as JSON and return it only if it is a JSON object.

    Empty input, undecodable bytes, invalid JSON and non-object JSON values
    (arrays, strings, numbers) all yield ``{}``.
    """
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except (ValueError, TypeError):
        return {}
    return parsed if isinstance(parsed, dict) else {}


async def _read_and_close(resp: httpx.Response, limit: int = 200) -> str:
    """Return up to *limit* characters of *resp*'s body, then close the response.

    Works for both streamed (unread) and buffered responses. A failure to read
    yields ``""`` because the snippet is only diagnostic.
    """
    try:
        await resp.aread()
        return resp.text[:limit]
    except Exception:
        return ""
    finally:
        await resp.aclose()


def _record_success(
    backend: Backend,
    model: str,
    resp: httpx.Response,
    is_stream: bool,
    elapsed_ms: int,
) -> tuple[int, float]:
    """Record usage for a forwarded response and return ``(total_tokens, cost)``.

    Best-effort: any failure is logged and reported as ``(0, 0.0)`` so that a
    bookkeeping problem never causes an already-answered request to be re-sent
    to another backend. Streamed bodies are not parsed, so they record zero
    tokens.
    """
    try:
        resp_json = {} if is_stream else _parse_json_object(resp.content)
        prompt_tokens, completion_tokens, total_tokens = extract_usage_from_response(
            resp, resp_json, model
        )
        cost = calculate_cost(backend, model, prompt_tokens, completion_tokens)
        record_usage(
            backend_id=backend.id,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cost=cost,
            latency_ms=elapsed_ms,
        )
    except Exception as exc:
        logger.error(
            "usage_tracking_failed",
            backend_id=backend.id,
            model=model,
            error=str(exc),
        )
        return 0, 0.0
    return total_tokens, cost


async def handle_proxy_request(
    pool_manager: PoolManager,
    rate_limiter: PerBackendRateLimiter,
    router: Router,
    request: Request,
    path: str,
) -> Response:
    """Main proxy handler: multi-pool routing with cooldown and rate limiting.

    Flow:
    1. Extract the model name from the JSON request body.
    2. Pick a backend via ``router.pick_backend`` (up to
       ``config.max_pool_retries`` attempts; stop early if it returns ``None``).
    3. Forward the request.
    4. On 429, start a cooldown for that backend and retry with another pick.
       Timeouts, connection errors and other forwarding failures also move on
       to the next attempt.
    5. If no attempt succeeds, switch to emergency passthrough. Try backends
       from ``pool_manager.get_any_healthy_backends()`` that the rate limiter
       admits at ``default_rpm_limit * emergency_rpm_fraction`` RPM (minimum 1),
       for up to 3 rounds.
    6. Track usage for the response that is returned.

    Returns a 503 ``AllPoolsExhausted`` error when every attempt fails.
    """
    start_time = time.monotonic()
    body_bytes: bytes = await request.body()
    raw_headers: dict[str, str] = dict(request.headers)
    upstream_body = body_bytes or None
    body_json = _parse_json_object(body_bytes)
    canonical_model = extract_model(body_json)
    is_stream = bool(body_json.get("stream", False))

    # Try with pool routing
    for _ in range(config.max_pool_retries):
        # Check and clear expired cooldowns before picking
        _refresh_cooldowns()
        backend = router.pick_backend(canonical_model)
        if backend is None:
            break  # No backend available, fall through to emergency
        try:
            resp = await forward_to_backend(
                backend=backend,
                method=request.method,
                path=path,
                body=upstream_body,
                headers=raw_headers,
                stream=is_stream,
            )
            elapsed_ms = int((time.monotonic() - start_time) * 1000)
            # Handle 429 — cooldown and retry
            if resp.status_code == 429:
                new_count = backend.consecutive_429_count + 1
                start_cooldown(backend.id, new_count)
                # Reading via aread() also works for streamed responses, and
                # closing frees the upstream connection before the retry.
                upstream_error = await _read_and_close(resp)
                logger.warning(
                    "backend_429_cooldown",
                    backend_id=backend.id,
                    pool=backend.pool,
                    consecutive=new_count,
                    model=canonical_model,
                    upstream_body=upstream_error,
                )
                record_usage(
                    backend_id=backend.id,
                    model=canonical_model,
                    prompt_tokens=0,
                    completion_tokens=0,
                    cost=0.0,
                    latency_ms=elapsed_ms,
                    is_error=True,
                )
                continue  # Retry with another backend
            total_tokens, cost = _record_success(
                backend, canonical_model, resp, is_stream, elapsed_ms
            )
            logger.info(
                "request_completed",
                backend_id=backend.id,
                pool=backend.pool,
                model=canonical_model,
                status=resp.status_code,
                tokens=total_tokens,
                cost=round(cost, 6),
                elapsed_ms=elapsed_ms,
            )
            return build_response(resp)
        except httpx.TimeoutException:
            logger.warning(
                "backend_timeout",
                backend_id=backend.id,
                model=canonical_model,
            )
            continue
        except (httpx.ConnectError, httpx.RemoteProtocolError) as exc:
            logger.warning(
                "backend_connection_error",
                backend_id=backend.id,
                model=canonical_model,
                error=str(exc),
            )
            continue
        except Exception as exc:
            logger.error(
                "proxy_error",
                backend_id=backend.id,
                model=canonical_model,
                error=str(exc),
            )
            continue

    # All pools exhausted — emergency rate-limited passthrough
    emergency_rpm = max(1, int(config.default_rpm_limit * config.emergency_rpm_fraction))
    logger.warning(
        "all_pools_exhausted_emergency",
        model=canonical_model,
        emergency_rpm=emergency_rpm,
    )
    # Emergency: try to get a token from any fallback backend at reduced RPM
    emergency_retries = 3
    for _ in range(emergency_retries):
        for backend in pool_manager.get_any_healthy_backends():
            if not rate_limiter.consume(backend.id, emergency_rpm):
                continue
            try:
                resp = await forward_to_backend(
                    backend=backend,
                    method=request.method,
                    path=path,
                    body=upstream_body,
                    headers=raw_headers,
                    stream=is_stream,
                )
                elapsed_ms = int((time.monotonic() - start_time) * 1000)
                if resp.status_code == 429:
                    start_cooldown(backend.id, backend.consecutive_429_count + 1)
                    await _read_and_close(resp)
                    # Keep error accounting consistent with the routed path.
                    record_usage(
                        backend_id=backend.id,
                        model=canonical_model,
                        prompt_tokens=0,
                        completion_tokens=0,
                        cost=0.0,
                        latency_ms=elapsed_ms,
                        is_error=True,
                    )
                    continue
                # Success in emergency mode
                _record_success(backend, canonical_model, resp, is_stream, elapsed_ms)
                logger.info(
                    "emergency_passthrough_success",
                    backend_id=backend.id,
                    model=canonical_model,
                    emergency_rpm=emergency_rpm,
                )
                return build_response(resp)
            except Exception as exc:
                # Previously swallowed silently; keep going but leave a trace.
                logger.warning(
                    "emergency_backend_error",
                    backend_id=backend.id,
                    model=canonical_model,
                    error=str(exc),
                )
                continue

    # All emergency attempts failed — return 503 for OpenClaw fallback chain
    return build_error_response(
        503,
        "All provider pools exhausted. OpenClaw fallback chain should activate.",
        "AllPoolsExhausted",
    )
def _refresh_cooldowns() -> None:
    """Clear expired cooldowns before the proxy picks a backend.

    Loads every backend via ``list_backends(decrypt_key=False)`` and calls
    ``check_and_clear_cooldown`` for each one whose status is ``"cooling"``.
    The status filtering happens here in Python, not in the storage query, so
    every call reads all backends. Periodic scanning is, per the original
    note, handled by a health-check loop elsewhere; this is the on-demand
    refresh done before proxy routing.
    """
    # Imported lazily as in the original — presumably to avoid an import cycle; TODO confirm.
    from storage.backend_store import list_backends

    for backend in list_backends(decrypt_key=False):
        if backend.status == "cooling":
            check_and_clear_cooldown(backend.id)