376ce97d91
When all primary backends are in cooldown, wait and retry the primary pool before falling through to fallback/emergency. This reduces unnecessary spend on paid fallback providers during temporary 429 storms.

Config:
- primary_wait_ms (default 5000, env SIDECAR_PRIMARY_WAIT_MS)
- primary_wait_max_retries (default 6, env SIDECAR_PRIMARY_WAIT_MAX_RETRIES)

Implementation:
- config.py: 2 new config fields + env var loading
- router.py: pick_primary_backend() — primary-pool-only selection
- proxy.py: primary-wait loop between the standard retries and emergency mode

Expected win: the 17% error rate seen under high concurrency should drop, and the emergency passthrough count should fall, because requests wait for the NVIDIA pool to recover instead of being routed immediately to the SiliconFlow fallback.
504 lines
17 KiB
Python
504 lines
17 KiB
Python
"""Proxy request handling for Sidecar V2 — multi-pool routing + cooldown + rate limiting."""
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
from typing import Any, Optional
|
|
|
|
import httpx
|
|
import structlog
|
|
from fastapi import Request
|
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
|
|
|
from config import config
|
|
from pool_manager import PoolManager
|
|
from rate_limiter import PerBackendRateLimiter
|
|
from router import Router
|
|
from cooldown_manager import start_cooldown, check_and_clear_cooldown
|
|
from storage.models import Backend
|
|
from storage.usage_store import record_usage
|
|
|
|
# Number of times emergency passthrough mode has been entered since this
# module was imported (read by metrics endpoint).
_emergency_count: int = 0


def get_emergency_count() -> int:
    """Return the emergency-mode activation counter for metrics reporting."""
    current: int = _emergency_count
    return current
|
|
|
|
|
|
# Module-level structured logger; every event emitted by this module is
# tagged with the "sidecar_v2.proxy" logger name.
logger: structlog.stdlib.BoundLogger = structlog.get_logger(
    "sidecar_v2.proxy"
)
|
|
|
|
|
|
def extract_model(body: dict[str, Any]) -> str:
    """Return the request body's ``model`` field as a string.

    Falls back to ``"unknown"`` when the key is absent. The value is only
    passed through ``str()``; no alias or canonical-name resolution happens
    here.
    """
    model = body["model"] if "model" in body else "unknown"
    return str(model)
|
|
|
|
|
|
def build_error_response(status: int, message: str, error_type: str = "") -> JSONResponse:
    """Build a JSON error response of the form ``{"error": {"message", "type"}}``.

    Args:
        status: HTTP status code for the response.
        message: Human-readable error message.
        error_type: Error type label; when empty, ``"Error_<status>"`` is used.
    """
    resolved_type = error_type if error_type else f"Error_{status}"
    payload = {"error": {"message": message, "type": resolved_type}}
    return JSONResponse(status_code=status, content=payload)
|
|
|
|
|
|
class _ClientClosingStream(httpx.AsyncByteStream):
    """Response body stream that also closes its owning ``AsyncClient``.

    ``forward_to_backend`` creates one client per call. For streamed responses
    the body is consumed after ``forward_to_backend`` has returned, so the
    client must outlive the call and be closed together with the response.
    """

    def __init__(self, inner: httpx.AsyncByteStream, client: httpx.AsyncClient) -> None:
        self._inner = inner
        self._client = client

    async def __aiter__(self):
        async for chunk in self._inner:
            yield chunk

    async def aclose(self) -> None:
        try:
            await self._inner.aclose()
        finally:
            await self._client.aclose()


def _join_upstream_url(base_url: str, path: str) -> str:
    """Join ``base_url`` and ``path`` without duplicating a shared segment.

    When the base URL already ends with the path's first segment (e.g. base
    ``https://host/v1`` and path ``/v1/chat/completions``), that segment is
    emitted only once. ``path`` may be given with or without a leading slash.
    """
    base = base_url.rstrip("/")
    if not path:
        return base
    segments = path.removeprefix("/").split("/")
    base_last = base.rsplit("/", 1)[-1]
    if "/" in base and segments[0] and segments[0] == base_last:
        segments = segments[1:]
    return base + "/" + "/".join(segments)


async def forward_to_backend(
    backend: Backend,
    method: str,
    path: str,
    body: bytes | None,
    headers: dict[str, str],
    stream: bool = False,
) -> httpx.Response:
    """Forward a request to a specific backend.

    Args:
        backend: Target backend; its ``api_base_url``, ``api_key_plain`` and
            ``timeout_seconds`` are used.
        method: HTTP method to send upstream.
        path: Request path, with or without a leading slash. A leading
            segment duplicated by the base URL (e.g. ``v1``) is dropped.
        body: Raw request body, or ``None``.
        headers: Incoming request headers. ``host``, ``content-length`` and
            ``transfer-encoding`` are not forwarded.
        stream: When true, the response body is not read. The caller must
            consume or ``aclose()`` the response; closing it also closes the
            per-call HTTP client.

    Returns:
        The upstream ``httpx.Response``.

    Raises:
        httpx.HTTPError subclasses (timeouts, connection errors) from the send.
    """
    upstream_url = _join_upstream_url(backend.api_base_url, path)

    forward_headers = {
        k: v
        for k, v in headers.items()
        if k.lower() not in ("host", "content-length", "transfer-encoding")
    }

    has_client_auth = any(k.lower() == "authorization" for k in forward_headers)
    if backend.api_key_plain:
        # Drop any client-supplied credential, whatever its casing, so exactly
        # one Authorization header reaches the backend.
        forward_headers = {
            k: v for k, v in forward_headers.items() if k.lower() != "authorization"
        }
        forward_headers["authorization"] = f"Bearer {backend.api_key_plain}"
    elif not has_client_auth:
        # Placeholder credential for backends configured without a key.
        # NOTE(review): presumably for keyless endpoints — confirm.
        forward_headers["authorization"] = "Bearer nvidia"

    client = httpx.AsyncClient(timeout=httpx.Timeout(backend.timeout_seconds))
    try:
        upstream_request = client.build_request(
            method=method,
            url=upstream_url,
            headers=forward_headers,
            content=body,
        )
        resp = await client.send(upstream_request, stream=stream)
    except BaseException:
        await client.aclose()
        raise

    if not stream:
        # send() has already read the whole body; the client is no longer needed.
        await client.aclose()
        return resp

    # Streamed body is read later by the caller: keep the client open until
    # the response stream itself is closed.
    resp.stream = _ClientClosingStream(resp.stream, client)
    return resp
|
|
|
|
|
|
def build_response(resp: httpx.Response) -> Response:
    """Convert an upstream ``httpx.Response`` into a FastAPI response.

    A ``StreamingResponse`` is returned for server-sent events, for chunked
    bodies (unless labelled ``application/octet-stream``), and for any body
    that has not been read yet. The last case is a ``stream=True`` request
    whose upstream replied with a plain body. Everything else becomes a
    buffered ``Response``.

    ``content-encoding``, ``transfer-encoding`` and ``content-length`` are not
    copied. httpx has already decoded the body, so the upstream values no
    longer describe it; the outgoing response sets its own framing.
    """
    from fastapi import BackgroundTasks

    content_type = resp.headers.get("content-type", "")
    headers = {
        k: v
        for k, v in resp.headers.items()
        if k.lower() not in ("content-encoding", "transfer-encoding", "content-length")
    }

    is_sse = "text/event-stream" in content_type
    is_chunked = resp.headers.get("transfer-encoding", "").lower() == "chunked"
    # resp.content would raise for a streamed response whose body is unread.
    body_pending = not resp.is_stream_consumed

    if (
        is_sse
        or body_pending
        or (is_chunked and headers.get("content-type", "") != "application/octet-stream")
    ):
        # Release the upstream response once sending finishes. Response.aclose
        # is idempotent, so this is harmless when aiter_bytes already closed it.
        cleanup = BackgroundTasks()
        cleanup.add_task(resp.aclose)
        return StreamingResponse(
            content=resp.aiter_bytes(),
            status_code=resp.status_code,
            headers=headers,
            media_type=content_type or "text/event-stream",
            background=cleanup,
        )

    return Response(
        content=resp.content,
        status_code=resp.status_code,
        headers=headers,
        media_type=content_type or "application/json",
    )
|
|
|
|
|
|
def extract_usage_from_response(
    resp: httpx.Response,
    resp_json: dict[str, Any],
    model: str,
) -> tuple[int, int, int]:
    """Extract token usage from an OpenAI-compatible response body.

    Reads ``usage.prompt_tokens`` / ``usage.completion_tokens``. If both are
    zero or missing, sums any per-choice ``usage`` objects instead. The total
    is prompt + completion; when that is 0, ``usage.total_tokens`` is used.

    Missing, ``null`` or non-dict ``usage`` / ``choices`` values count as zero
    usage instead of raising. ``resp`` and ``model`` are accepted for
    interface compatibility and are not used.

    Returns:
        ``(prompt_tokens, completion_tokens, total_tokens)``.
    """
    if not isinstance(resp_json, dict):
        return 0, 0, 0

    raw_usage = resp_json.get("usage")
    usage = raw_usage if isinstance(raw_usage, dict) else {}
    prompt_tokens = usage.get("prompt_tokens", 0) or 0
    completion_tokens = usage.get("completion_tokens", 0) or 0

    # Fallback: aggregate per-choice usage when no top-level counts exist.
    if not prompt_tokens and not completion_tokens:
        for choice in resp_json.get("choices") or []:
            if not isinstance(choice, dict):
                continue
            raw_choice_usage = choice.get("usage")
            choice_usage = raw_choice_usage if isinstance(raw_choice_usage, dict) else {}
            prompt_tokens += choice_usage.get("prompt_tokens", 0) or 0
            completion_tokens += choice_usage.get("completion_tokens", 0) or 0

    total_tokens = prompt_tokens + completion_tokens
    if total_tokens == 0:
        total_tokens = usage.get("total_tokens", 0) or 0

    return prompt_tokens, completion_tokens, total_tokens
|
|
|
|
|
|
def calculate_cost(
    backend: Backend,
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> float:
    """Price a request using the backend's per-model rates.

    ``backend.get_model_cost(model)`` supplies ``"input"`` and ``"output"``
    rates, each defaulting to 0.0 when absent. Rates are applied per token.
    """
    pricing = backend.get_model_cost(model)
    input_rate = pricing.get("input", 0.0)
    output_rate = pricing.get("output", 0.0)
    input_total = prompt_tokens * input_rate
    output_total = completion_tokens * output_rate
    return input_total + output_total
|
|
|
|
|
|
async def _release_response(resp: httpx.Response) -> None:
    """Close an upstream response that will not be relayed (best effort).

    For streamed requests the body has not been read, so the upstream
    connection stays open until the response is closed.
    """
    try:
        await resp.aclose()
    except Exception as exc:  # cleanup must never mask the routing decision
        logger.debug("upstream_close_failed", error=str(exc))


async def _register_rate_limit(
    backend: Backend,
    resp: httpx.Response,
    model: str,
    elapsed_ms: int,
    log: Any,
    event: str,
    **log_fields: Any,
) -> None:
    """Handle an upstream 429: close the response, cool the backend down, log, record."""
    await _release_response(resp)
    new_count = backend.consecutive_429_count + 1
    start_cooldown(backend.id, new_count)
    log.warning(
        event,
        backend_id=backend.id,
        consecutive=new_count,
        model=model,
        **log_fields,
    )
    record_usage(
        backend_id=backend.id,
        model=model,
        prompt_tokens=0,
        completion_tokens=0,
        cost=0.0,
        latency_ms=elapsed_ms,
        is_error=True,
    )


def _record_success(
    backend: Backend,
    resp: httpx.Response,
    model: str,
    is_stream: bool,
    elapsed_ms: int,
) -> tuple[int, float]:
    """Record token usage and cost for a response that will be relayed.

    Best effort: any tracking failure is logged and ``(0, 0.0)`` is returned.
    A bookkeeping error therefore never discards, or re-sends, a request that
    already succeeded upstream.

    Returns:
        ``(total_tokens, cost)``.
    """
    try:
        resp_json: dict[str, Any] = {}
        if not is_stream and resp.content:
            try:
                parsed = json.loads(resp.content)
            except (ValueError, TypeError):
                parsed = None
            if isinstance(parsed, dict):
                resp_json = parsed

        prompt_tokens, completion_tokens, total_tokens = extract_usage_from_response(
            resp, resp_json, model
        )
        cost = calculate_cost(backend, model, prompt_tokens, completion_tokens)
        record_usage(
            backend_id=backend.id,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cost=cost,
            latency_ms=elapsed_ms,
        )
    except Exception as exc:
        logger.error(
            "usage_tracking_failed",
            backend_id=backend.id,
            model=model,
            error=str(exc),
        )
        return 0, 0.0
    return total_tokens, cost


async def handle_proxy_request(
    pool_manager: PoolManager,
    rate_limiter: PerBackendRateLimiter,
    router: Router,
    request: Request,
    path: str,
) -> Response:
    """Proxy one client request with multi-pool routing, cooldown and rate limiting.

    Flow:
        1. Read the body and extract the requested model (``extract_model``).
        2. Phase 1 makes up to ``config.max_pool_retries`` attempts, each on
           the backend chosen by ``router.pick_backend`` (primary → fallback).
           A 429 puts the backend into cooldown before the next pick.
           Timeouts, connection errors and other exceptions are logged and
           retried.
        3. Phase 2 (primary-wait) runs up to
           ``config.primary_wait_max_retries`` times. Each round sleeps
           ``config.primary_wait_ms`` and then tries only
           ``router.pick_primary_backend``.
        4. Phase 3 (emergency) makes up to 3 rounds over
           ``pool_manager.get_any_healthy_backends()``. Each attempt is gated
           by ``rate_limiter.consume`` at a reduced RPM.
        5. If every phase fails, return 503 ``AllPoolsExhausted``.

    Usage is recorded for every relayed response and every 429. Any non-429
    upstream status, including 5xx, is relayed to the caller as-is.
    """
    global _emergency_count
    start_time = time.monotonic()

    body_bytes: bytes = await request.body()
    raw_headers: dict[str, str] = dict(request.headers)

    body_json: dict[str, Any] = {}
    try:
        if body_bytes:
            parsed = json.loads(body_bytes)
            if isinstance(parsed, dict):
                body_json = parsed
    except (ValueError, TypeError):
        body_json = {}

    canonical_model = extract_model(body_json)
    is_stream = bool(body_json.get("stream", False))

    async def send_to(target: Backend) -> httpx.Response:
        return await forward_to_backend(
            backend=target,
            method=request.method,
            path=path,
            body=body_bytes if body_bytes else None,
            headers=raw_headers,
            stream=is_stream,
        )

    def elapsed() -> int:
        return int((time.monotonic() - start_time) * 1000)

    # --- Phase 1: normal routing (primary -> fallback, as decided by Router) ---
    for _attempt in range(config.max_pool_retries):
        _refresh_cooldowns()
        backend = router.pick_backend(canonical_model)
        if backend is None:
            break  # No backend available; fall through to primary-wait/emergency

        try:
            resp = await send_to(backend)
            elapsed_ms = elapsed()
            if resp.status_code == 429:
                await _register_rate_limit(
                    backend, resp, canonical_model, elapsed_ms,
                    logger, "backend_429_cooldown", pool=backend.pool,
                )
                continue  # Retry with another backend
        except httpx.TimeoutException:
            logger.warning("backend_timeout", backend_id=backend.id, model=canonical_model)
            continue
        except (httpx.ConnectError, httpx.RemoteProtocolError) as exc:
            logger.warning(
                "backend_connection_error",
                backend_id=backend.id,
                model=canonical_model,
                error=str(exc),
            )
            continue
        except Exception as exc:
            logger.error(
                "proxy_error",
                backend_id=backend.id,
                model=canonical_model,
                error=str(exc),
            )
            continue

        total_tokens, cost = _record_success(
            backend, resp, canonical_model, is_stream, elapsed_ms
        )
        logger.info(
            "request_completed",
            backend_id=backend.id,
            pool=backend.pool,
            model=canonical_model,
            status=resp.status_code,
            tokens=total_tokens,
            cost=round(cost, 6),
            elapsed_ms=elapsed_ms,
        )
        return build_response(resp)

    # --- Phase 2: wait for primary pool recovery before emergency ---
    pwl = logger.bind(phase="primary_wait")
    for pw_attempt in range(1, config.primary_wait_max_retries + 1):
        await asyncio.sleep(config.primary_wait_ms / 1000.0)
        _refresh_cooldowns()

        backend = router.pick_primary_backend(canonical_model)
        if backend is None:
            pwl.debug("primary_wait_no_backend", attempt=pw_attempt, model=canonical_model)
            continue

        try:
            resp = await send_to(backend)
            elapsed_ms = elapsed()
            if resp.status_code == 429:
                await _register_rate_limit(
                    backend, resp, canonical_model, elapsed_ms,
                    pwl, "primary_wait_429", attempt=pw_attempt,
                )
                continue
        except httpx.TimeoutException:
            pwl.warning(
                "primary_wait_timeout",
                backend_id=backend.id,
                attempt=pw_attempt,
                model=canonical_model,
            )
            continue
        except (httpx.ConnectError, httpx.RemoteProtocolError) as exc:
            pwl.warning(
                "primary_wait_connection_error",
                backend_id=backend.id,
                attempt=pw_attempt,
                model=canonical_model,
                error=str(exc),
            )
            continue
        except Exception as exc:
            pwl.error(
                "primary_wait_error",
                backend_id=backend.id,
                attempt=pw_attempt,
                model=canonical_model,
                error=str(exc),
            )
            continue

        # Primary recovered
        total_tokens, cost = _record_success(
            backend, resp, canonical_model, is_stream, elapsed_ms
        )
        logger.info(
            "primary_wait_recovery",
            backend_id=backend.id,
            pool=backend.pool,
            model=canonical_model,
            status=resp.status_code,
            tokens=total_tokens,
            cost=round(cost, 6),
            elapsed_ms=elapsed_ms,
            pw_attempt=pw_attempt,
        )
        return build_response(resp)

    # --- Phase 3: all pools exhausted — emergency rate-limited passthrough ---
    emergency_rpm = max(1, int(config.default_rpm_limit * config.emergency_rpm_fraction))
    logger.warning(
        "all_pools_exhausted_emergency",
        model=canonical_model,
        emergency_rpm=emergency_rpm,
    )
    _emergency_count += 1  # read by get_emergency_count() for metrics

    emergency_retries = 3
    for _round in range(emergency_retries):
        for backend in pool_manager.get_any_healthy_backends():
            if not rate_limiter.consume(backend.id, emergency_rpm):
                continue

            try:
                resp = await send_to(backend)
                elapsed_ms = elapsed()
                if resp.status_code == 429:
                    await _register_rate_limit(
                        backend, resp, canonical_model, elapsed_ms,
                        logger, "emergency_429", pool=backend.pool,
                    )
                    continue
            except Exception as exc:
                logger.warning(
                    "emergency_backend_error",
                    backend_id=backend.id,
                    model=canonical_model,
                    error=str(exc),
                )
                continue

            _record_success(backend, resp, canonical_model, is_stream, elapsed_ms)
            logger.info(
                "emergency_passthrough_success",
                backend_id=backend.id,
                model=canonical_model,
                emergency_rpm=emergency_rpm,
            )
            return build_response(resp)

    # All emergency attempts failed — return 503 for OpenClaw fallback chain
    return build_error_response(
        503,
        "All provider pools exhausted. OpenClaw fallback chain should activate.",
        "AllPoolsExhausted",
    )
|
|
|
|
|
|
def _refresh_cooldowns() -> None:
    """Clear expired cooldowns on backends whose status is ``"cooling"``.

    All backends are listed (keys left encrypted) and filtered to the cooling
    ones. ``check_and_clear_cooldown`` is then called for each of them. This
    is the on-demand refresh run before proxy routing; periodic scanning is
    handled by the health-check loop.
    """
    from storage.backend_store import list_backends

    cooling_ids = (
        candidate.id
        for candidate in list_backends(decrypt_key=False)
        if candidate.status == "cooling"
    )
    for backend_id in cooling_ids:
        check_and_clear_cooldown(backend_id)