"""CRUD operations for Backend (provider) management."""

import json
import time
from typing import Optional

from storage.db import get_connection, generate_id
from storage.models import Backend, ModelMapping
from crypto import encrypt, decrypt

# Fields a caller may change through update_backend(), mapped to the DB column
# that stores each one. Column names interpolated into the dynamic UPDATE come
# only from this whitelist, never from caller-supplied keys.
_UPDATABLE_COLUMNS = {
    "name": "name",
    "label": "label",
    "api_base_url": "api_base_url",
    "api": "api",
    "timeout_seconds": "timeout_seconds",
    "rpm_limit": "rpm_limit",
    "pool": "pool",
    "enabled": "enabled",
    "status": "status",
    "source": "source",
    "cooldown_until": "cooldown_until",
    "consecutive_429_count": "consecutive_429_count",
    "metadata": "metadata_json",
}


def _utc_now() -> str:
    """Return the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``."""
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())


def create_backend(backend: Backend) -> Backend:
    """Insert *backend* as a new row and return it.

    The passed object is mutated: an id is generated with prefix ``"bkd"``
    if it has none, and ``created_at``/``updated_at`` are set to now.
    Only ``encrypt(backend.api_key_plain)`` is stored, never the plain key.
    """
    if not backend.id:
        backend.id = generate_id("bkd")
    now = _utc_now()
    backend.created_at = now
    backend.updated_at = now

    api_key_encrypted = encrypt(backend.api_key_plain)

    with get_connection() as conn:
        conn.execute(
            """INSERT INTO backends
               (id, name, label, api_base_url, api_key_encrypted, api,
                timeout_seconds, rpm_limit, pool, enabled, status,
                model_mappings_json, source, cooldown_until,
                consecutive_429_count, metadata_json, created_at, updated_at)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (
                backend.id,
                backend.name,
                backend.label,
                backend.api_base_url,
                api_key_encrypted,
                backend.api,
                backend.timeout_seconds,
                backend.rpm_limit,
                backend.pool,
                1 if backend.enabled else 0,
                backend.status,
                json.dumps(_mappings_to_dict(backend.model_mappings)),
                backend.source,
                backend.cooldown_until,
                backend.consecutive_429_count,
                json.dumps(backend.metadata),
                backend.created_at,
                backend.updated_at,
            ),
        )
        conn.commit()
    return backend


def get_backend(backend_id: str, decrypt_key: bool = True) -> Optional[Backend]:
    """Return the backend with *backend_id*, or None if no such row exists.

    Args:
        backend_id: Primary key of the backend.
        decrypt_key: If True, try to decrypt the stored API key into
            ``api_key_plain`` (see ``_row_to_backend``).
    """
    with get_connection() as conn:
        row = conn.execute(
            "SELECT * FROM backends WHERE id = ?", (backend_id,)
        ).fetchone()
        if row is None:
            return None
        return _row_to_backend(row, decrypt_key=decrypt_key)


def list_backends(
    pool: Optional[str] = None,
    enabled_only: bool = False,
    decrypt_key: bool = False,
) -> list[Backend]:
    """List backends, optionally filtered by pool and enabled flag.

    Args:
        pool: If truthy, return only backends in this pool, ordered by
            ``created_at``; otherwise return all backends ordered by pool,
            then ``created_at``.
        enabled_only: If True, omit backends whose ``enabled`` column is falsy.
        decrypt_key: If True, try to decrypt each stored API key.
    """
    with get_connection() as conn:
        if pool:
            rows = conn.execute(
                "SELECT * FROM backends WHERE pool = ? ORDER BY created_at",
                (pool,),
            ).fetchall()
        else:
            rows = conn.execute(
                "SELECT * FROM backends ORDER BY pool, created_at"
            ).fetchall()
        if enabled_only:
            # Filter on the raw column before conversion so disabled rows are
            # never converted/decrypted; same truthiness as Backend.enabled,
            # which _row_to_backend sets to bool(row["enabled"]).
            rows = [r for r in rows if r["enabled"]]
        return [_row_to_backend(r, decrypt_key=decrypt_key) for r in rows]


def update_backend(backend_id: str, updates: dict) -> Optional[Backend]:
    """Update selected fields of a backend.

    Only the keys present in *updates* are written, plus ``updated_at``.
    Columns the caller did not mention are left untouched. This avoids
    overwriting ``status``/``cooldown_until``/``consecutive_429_count`` that
    ``set_backend_status``/``set_backend_cooldown``/``clear_backend_cooldown``
    may have changed since the row was last read.

    Recognised keys:
        * any key of ``_UPDATABLE_COLUMNS`` (``enabled`` is stored as 1/0,
          ``metadata`` is JSON-encoded into ``metadata_json``);
        * ``api_key_plain`` - re-encrypted and stored when non-empty;
        * ``model_mappings`` - dict of canonical name -> ModelMapping,
          JSON-encoded into ``model_mappings_json``.
    Unknown keys are ignored.

    Returns:
        The updated backend (API key not decrypted), or None if no backend
        with *backend_id* exists.
    """
    # Existence check only; the stored key is never used here, so skip
    # decrypting it.
    if get_backend(backend_id, decrypt_key=False) is None:
        return None

    set_clauses: list[str] = []
    params: list = []
    for field, column in _UPDATABLE_COLUMNS.items():
        if field not in updates:
            continue
        value = updates[field]
        if field == "enabled":
            value = 1 if value else 0
        elif field == "metadata":
            value = json.dumps(value)
        set_clauses.append(f"{column} = ?")
        params.append(value)

    # An empty/None api_key_plain means "leave the stored key alone".
    api_key_plain = updates.get("api_key_plain")
    if api_key_plain:
        set_clauses.append("api_key_encrypted = ?")
        params.append(encrypt(api_key_plain))

    if "model_mappings" in updates:
        set_clauses.append("model_mappings_json = ?")
        params.append(json.dumps(_mappings_to_dict(updates["model_mappings"])))

    set_clauses.append("updated_at = ?")
    params.append(_utc_now())
    params.append(backend_id)

    with get_connection() as conn:
        conn.execute(
            f"UPDATE backends SET {', '.join(set_clauses)} WHERE id = ?",
            params,
        )
        conn.commit()
    return get_backend(backend_id, decrypt_key=False)


def delete_backend(backend_id: str) -> bool:
    """Delete a backend. Returns True if a row was deleted."""
    with get_connection() as conn:
        cursor = conn.execute("DELETE FROM backends WHERE id = ?", (backend_id,))
        conn.commit()
        return cursor.rowcount > 0


def set_backend_status(backend_id: str, status: str) -> bool:
    """Set only the ``status`` column (and ``updated_at``) of a backend.

    Expected values include healthy/cooling/error/disabled; the value is not
    validated here. Returns True if a row was updated.
    """
    now = _utc_now()
    with get_connection() as conn:
        cursor = conn.execute(
            "UPDATE backends SET status = ?, updated_at = ? WHERE id = ?",
            (status, now, backend_id),
        )
        conn.commit()
        return cursor.rowcount > 0


def set_backend_cooldown(backend_id: str, cooldown_until: str, count: int) -> bool:
    """Mark a backend as ``'cooling'`` until *cooldown_until*.

    Also stores *count* in ``consecutive_429_count``. Returns True if a row
    was updated.
    """
    now = _utc_now()
    with get_connection() as conn:
        cursor = conn.execute(
            """UPDATE backends
               SET status = 'cooling', cooldown_until = ?,
                   consecutive_429_count = ?, updated_at = ?
               WHERE id = ?""",
            (cooldown_until, count, now, backend_id),
        )
        conn.commit()
        return cursor.rowcount > 0


def clear_backend_cooldown(backend_id: str) -> bool:
    """Clear a backend's cooldown and set its status back to ``'healthy'``.

    Resets ``cooldown_until`` to NULL and ``consecutive_429_count`` to 0.
    Returns True if a row was updated.
    """
    now = _utc_now()
    with get_connection() as conn:
        cursor = conn.execute(
            """UPDATE backends
               SET status = 'healthy', cooldown_until = NULL,
                   consecutive_429_count = 0, updated_at = ?
               WHERE id = ?""",
            (now, backend_id),
        )
        conn.commit()
        return cursor.rowcount > 0


def get_pool_stats() -> dict:
    """Return per-pool counts keyed by pool name.

    Each value is a dict with keys ``total``, ``enabled``, ``healthy``,
    ``cooling`` and ``error``.
    """
    with get_connection() as conn:
        rows = conn.execute(
            """SELECT pool,
                      COUNT(*) as total,
                      SUM(CASE WHEN enabled = 1 THEN 1 ELSE 0 END) as enabled,
                      SUM(CASE WHEN status = 'healthy' THEN 1 ELSE 0 END) as healthy,
                      SUM(CASE WHEN status = 'cooling' THEN 1 ELSE 0 END) as cooling,
                      SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as error
               FROM backends
               GROUP BY pool"""
        ).fetchall()
    stat_keys = ("total", "enabled", "healthy", "cooling", "error")
    return {row["pool"]: {key: row[key] for key in stat_keys} for row in rows}


def _row_to_backend(row, decrypt_key: bool = True) -> Backend:
    """Convert a ``backends`` DB row into a Backend instance.

    JSON columns (``model_mappings_json``, ``metadata_json``) default to
    ``{}`` when NULL/empty. When *decrypt_key* is True and an encrypted key is
    stored, ``crypto.try_decrypt_existing`` is called. ``api_key_plain`` is
    set only if that call returns a truthy value.
    """
    mappings_dict = json.loads(row["model_mappings_json"] or "{}")
    model_mappings = {
        canonical_name: ModelMapping.from_dict(mm)
        for canonical_name, mm in mappings_dict.items()
    }

    backend = Backend(
        id=row["id"],
        name=row["name"],
        label=row["label"],
        api_base_url=row["api_base_url"],
        api_key_encrypted=row["api_key_encrypted"] or "",
        api=row["api"],
        timeout_seconds=row["timeout_seconds"],
        rpm_limit=row["rpm_limit"],
        pool=row["pool"],
        enabled=bool(row["enabled"]),
        status=row["status"],
        model_mappings=model_mappings,
        source=row["source"],
        cooldown_until=row["cooldown_until"],
        consecutive_429_count=row["consecutive_429_count"],
        metadata=json.loads(row["metadata_json"] or "{}"),
        created_at=row["created_at"],
        updated_at=row["updated_at"],
    )

    if decrypt_key and backend.api_key_encrypted:
        from crypto import try_decrypt_existing

        plain = try_decrypt_existing(backend.api_key_encrypted)
        if plain:
            backend.api_key_plain = plain

    return backend


def _mappings_to_dict(mappings: dict[str, ModelMapping]) -> dict:
    """Convert a dict of ModelMapping objects to a JSON-serialisable dict."""
    return {k: v.to_dict() for k, v in mappings.items()}