"""Proxy request handling for Sidecar V2 — multi-pool routing + cooldown + rate limiting.""" import asyncio import json import time from typing import Any, Optional import httpx import structlog from fastapi import Request from fastapi.responses import JSONResponse, Response, StreamingResponse from config import config from pool_manager import PoolManager from rate_limiter import PerBackendRateLimiter from router import Router from cooldown_manager import start_cooldown, check_and_clear_cooldown from storage.models import Backend from storage.usage_store import record_usage # Emergency activation counter (read by metrics endpoint) _emergency_count: int = 0 def get_emergency_count() -> int: return _emergency_count logger: structlog.stdlib.BoundLogger = structlog.get_logger("sidecar_v2.proxy") def extract_model(body: dict[str, Any]) -> str: """Extract model identifier from request body.""" return str(body.get("model", "unknown")) def build_error_response(status: int, message: str, error_type: str = "") -> JSONResponse: """Build a standard error response.""" return JSONResponse( status_code=status, content={ "error": { "message": message, "type": error_type or f"Error_{status}", } }, ) async def forward_to_backend( backend: Backend, method: str, path: str, body: bytes | None, headers: dict[str, str], stream: bool = False, ) -> httpx.Response: """Forward a request to a specific backend.""" # Strip API version prefix from path when base URL already has one base = backend.api_base_url.rstrip("/") # Check if the last path segment of base matches the first segment of path base_parts = base.split("/") path_parts = path.split("/") if len(base_parts) >= 2 and len(path_parts) >= 2: base_last = base_parts[-1] path_first = path_parts[1] if len(path_parts) > 1 and path_parts[0] == "" else path_parts[0] if base_last == path_first: # Remove duplicate prefix from path path = "/" + "/".join(path_parts[2:]) upstream_url = base + path forward_headers = { k: v for k, v in headers.items() if k.lower() not in ("host", "content-length", "transfer-encoding") } if backend.api_key_plain: forward_headers["authorization"] = f"Bearer {backend.api_key_plain}" elif "authorization" not in {k.lower() for k in forward_headers}: forward_headers["authorization"] = "Bearer nvidia" timeout = httpx.Timeout(backend.timeout_seconds) async with httpx.AsyncClient(timeout=timeout) as client: req = client.build_request( method=method, url=upstream_url, headers=forward_headers, content=body, ) return await client.send(req, stream=stream) def build_response(resp: httpx.Response) -> Response: """Convert httpx.Response to FastAPI Response.""" content_type = resp.headers.get("content-type", "") headers = { k: v for k, v in resp.headers.items() if k.lower() not in ("content-encoding", "transfer-encoding") } is_sse = "text/event-stream" in content_type is_chunked = resp.headers.get("transfer-encoding", "").lower() == "chunked" if is_sse or (is_chunked and headers.get("content-type", "") != "application/octet-stream"): return StreamingResponse( content=resp.aiter_bytes(), status_code=resp.status_code, headers=headers, media_type=content_type or "text/event-stream", ) return Response( content=resp.content, status_code=resp.status_code, headers=headers, media_type=content_type or "application/json", ) def extract_usage_from_response( resp: httpx.Response, resp_json: dict[str, Any], model: str, ) -> tuple[int, int, int]: """Extract token usage from response body (OpenAI-compatible).""" usage = resp_json.get("usage", {}) prompt_tokens = usage.get("prompt_tokens", 0) or 0 completion_tokens = usage.get("completion_tokens", 0) or 0 # Try streaming chunks: aggregate from choices if not prompt_tokens and not completion_tokens: choices = resp_json.get("choices", []) for choice in choices: if isinstance(choice, dict): tokens = choice.get("usage", {}) prompt_tokens += tokens.get("prompt_tokens", 0) or 0 completion_tokens += tokens.get("completion_tokens", 0) or 0 total_tokens = prompt_tokens + completion_tokens if total_tokens == 0: total_tokens = usage.get("total_tokens", 0) or 0 return prompt_tokens, completion_tokens, total_tokens def calculate_cost( backend: Backend, model: str, prompt_tokens: int, completion_tokens: int, ) -> float: """Calculate cost using backend's model pricing.""" cost_info = backend.get_model_cost(model) input_cost = cost_info.get("input", 0.0) output_cost = cost_info.get("output", 0.0) # Costs are per token return (prompt_tokens * input_cost + completion_tokens * output_cost) async def handle_proxy_request( pool_manager: PoolManager, rate_limiter: PerBackendRateLimiter, router: Router, request: Request, path: str, ) -> Response: """Main proxy handler: multi-pool routing with cooldown and rate limiting. Flow: 1. Extract model → canonical name 2. Pick backend via Router (primary → fallback) 3. Forward request 4. If 429 → cooldown backend, retry with another 5. If all pools exhausted → emergency mode 6. Track usage """ global _emergency_count start_time = time.monotonic() body_bytes: bytes = await request.body() raw_headers: dict[str, str] = dict(request.headers) body_json: dict[str, Any] = {} try: if body_bytes: parsed = json.loads(body_bytes) if isinstance(parsed, dict): body_json = parsed except (ValueError, TypeError): body_json = {} canonical_model = extract_model(body_json) is_stream = body_json.get("stream", False) # Try with pool routing max_retries = config.max_pool_retries for attempt in range(max_retries): # Check and clear expired cooldowns before picking _refresh_cooldowns() backend = router.pick_backend(canonical_model) if backend is None: break # No backend available, fall through to emergency try: resp = await forward_to_backend( backend=backend, method=request.method, path=path, body=body_bytes if body_bytes else None, headers=raw_headers, stream=is_stream, ) elapsed_ms = int((time.monotonic() - start_time) * 1000) # Handle 429 — cooldown and retry if resp.status_code == 429: new_count = backend.consecutive_429_count + 1 start_cooldown(backend.id, new_count) resp_body = "" try: resp_body = resp.text[:200] except Exception: pass logger.warning( "backend_429_cooldown", backend_id=backend.id, pool=backend.pool, consecutive=new_count, model=canonical_model, ) # Track the error record_usage( backend_id=backend.id, model=canonical_model, prompt_tokens=0, completion_tokens=0, cost=0.0, latency_ms=elapsed_ms, is_error=True, ) continue # Retry with another backend # Success — track usage resp_json: dict[str, Any] = {} try: if not is_stream and resp.content: resp_json = json.loads(resp.content) except (ValueError, TypeError): pass prompt_tokens, completion_tokens, total_tokens = extract_usage_from_response( resp, resp_json, canonical_model ) cost = calculate_cost( backend, canonical_model, prompt_tokens, completion_tokens ) record_usage( backend_id=backend.id, model=canonical_model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, cost=cost, latency_ms=elapsed_ms, ) logger.info( "request_completed", backend_id=backend.id, pool=backend.pool, model=canonical_model, status=resp.status_code, tokens=total_tokens, cost=round(cost, 6), elapsed_ms=elapsed_ms, ) return build_response(resp) except httpx.TimeoutException: logger.warning( "backend_timeout", backend_id=backend.id, model=canonical_model, ) continue except (httpx.ConnectError, httpx.RemoteProtocolError) as exc: logger.warning( "backend_connection_error", backend_id=backend.id, model=canonical_model, error=str(exc), ) continue except Exception as exc: logger.error( "proxy_error", backend_id=backend.id, model=canonical_model, error=str(exc), ) continue # --- Primary-Wait: wait for primary pool recovery before fallback/emergency --- pwl = logger.bind(phase="primary_wait") for pw_attempt in range(config.primary_wait_max_retries): await asyncio.sleep(config.primary_wait_ms / 1000.0) _refresh_cooldowns() backend = router.pick_primary_backend(canonical_model) if not backend: pwl.debug( "primary_wait_no_backend", attempt=pw_attempt + 1, model=canonical_model, ) continue try: resp = await forward_to_backend( backend=backend, method=request.method, path=path, body=body_bytes if body_bytes else None, headers=raw_headers, stream=is_stream, ) elapsed_ms = int((time.monotonic() - start_time) * 1000) if resp.status_code == 429: new_count = backend.consecutive_429_count + 1 start_cooldown(backend.id, new_count) pwl.warning( "primary_wait_429", backend_id=backend.id, attempt=pw_attempt + 1, consecutive=new_count, model=canonical_model, ) record_usage( backend_id=backend.id, model=canonical_model, prompt_tokens=0, completion_tokens=0, cost=0.0, latency_ms=elapsed_ms, is_error=True, ) continue # Primary recovered — success resp_json: dict[str, Any] = {} try: if not is_stream and resp.content: resp_json = json.loads(resp.content) except (ValueError, TypeError): pass prompt_tokens, completion_tokens, total_tokens = extract_usage_from_response( resp, resp_json, canonical_model ) cost = calculate_cost( backend, canonical_model, prompt_tokens, completion_tokens ) record_usage( backend_id=backend.id, model=canonical_model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, cost=cost, latency_ms=elapsed_ms, ) logger.info( "primary_wait_recovery", backend_id=backend.id, pool=backend.pool, model=canonical_model, status=resp.status_code, tokens=total_tokens, cost=round(cost, 6), elapsed_ms=elapsed_ms, pw_attempt=pw_attempt + 1, ) return build_response(resp) except httpx.TimeoutException: pwl.warning( "primary_wait_timeout", backend_id=backend.id, attempt=pw_attempt + 1, model=canonical_model, ) except (httpx.ConnectError, httpx.RemoteProtocolError) as exc: pwl.warning( "primary_wait_connection_error", backend_id=backend.id, attempt=pw_attempt + 1, model=canonical_model, error=str(exc), ) except Exception as exc: pwl.error( "primary_wait_error", backend_id=backend.id, attempt=pw_attempt + 1, model=canonical_model, error=str(exc), ) continue # All pools exhausted (including primary-wait retries) — emergency rate-limited passthrough emergency_rpm = int(config.default_rpm_limit * config.emergency_rpm_fraction) if emergency_rpm < 1: emergency_rpm = 1 logger.warning( "all_pools_exhausted_emergency", model=canonical_model, emergency_rpm=emergency_rpm, ) # Track emergency activation for metrics _emergency_count += 1 # Emergency: try to get a token from any fallback backend at reduced RPM emergency_retries = 3 for attempt in range(emergency_retries): backends = pool_manager.get_any_healthy_backends() for backend in backends: if rate_limiter.consume(backend.id, emergency_rpm): try: resp = await forward_to_backend( backend=backend, method=request.method, path=path, body=body_bytes if body_bytes else None, headers=raw_headers, stream=is_stream, ) elapsed_ms = int((time.monotonic() - start_time) * 1000) if resp.status_code == 429: start_cooldown(backend.id, backend.consecutive_429_count + 1) continue # Success in emergency mode try: resp_json: dict[str, Any] = {} if not is_stream and resp.content: resp_json = json.loads(resp.content) except Exception: resp_json = {} prompt_tokens, completion_tokens, total_tokens = extract_usage_from_response( resp, resp_json, canonical_model ) cost_em = calculate_cost(backend, canonical_model, prompt_tokens, completion_tokens) record_usage( backend_id=backend.id, model=canonical_model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, cost=cost_em, latency_ms=elapsed_ms, ) logger.info( "emergency_passthrough_success", backend_id=backend.id, model=canonical_model, emergency_rpm=emergency_rpm, ) return build_response(resp) except Exception: continue # All emergency attempts failed — return 503 for OpenClaw fallback chain return build_error_response( 503, "All provider pools exhausted. OpenClaw fallback chain should activate.", "AllPoolsExhausted", ) def _refresh_cooldowns() -> None: """Check and clear expired cooldowns for backends currently in cooling state. Only queries backends with status='cooling' (the health_check_loop handles the periodic scanning; this is the on-demand refresh before proxy routing).""" from storage.backend_store import list_backends backends = list_backends(decrypt_key=False) for backend in backends: if backend.status == "cooling": check_and_clear_cooldown(backend.id)