"""AES-256-GCM encryption for API Key storage."""

import os
import secrets

import structlog
from cryptography.hazmat.primitives.ciphers.aead import AESGCM

logger = structlog.get_logger()

# Raw key bytes and the cipher built from them; both stay None until
# init_crypto() succeeds.
_ENCRYPTION_KEY: bytes | None = None
_cipher: AESGCM | None = None

# AES-256 needs a 32-byte key, supplied as 64 hex characters.
_KEY_SIZE = 32
# Random nonce length (96 bits) prepended to every ciphertext by encrypt().
_NONCE_SIZE = 12


def init_crypto(hex_key: str) -> None:
    """Initialize the encryption module.

    Validates the key and prepares the cipher used by encrypt()/decrypt().

    Args:
        hex_key: The encryption key as exactly 64 hexadecimal characters
            (32 bytes), with no whitespace.

    Raises:
        ValueError: If the key is empty, not 64 characters long, not valid
            hexadecimal, or does not decode to exactly 32 bytes.
    """
    global _ENCRYPTION_KEY, _cipher

    if not hex_key:
        raise ValueError("FATAL: SIDECAR_ENCRYPTION_KEY not set")

    if len(hex_key) != _KEY_SIZE * 2:
        raise ValueError(
            "FATAL: SIDECAR_ENCRYPTION_KEY must be 64 hex chars (32 bytes), "
            f"got {len(hex_key)} chars"
        )

    try:
        key_bytes = bytes.fromhex(hex_key)
    except ValueError as err:
        raise ValueError(
            "FATAL: SIDECAR_ENCRYPTION_KEY must be valid hexadecimal"
        ) from err

    # bytes.fromhex() silently skips ASCII whitespace, so a 64-char string that
    # contains spaces can decode to fewer than 32 bytes. AESGCM would accept
    # a 16- or 24-byte key and quietly downgrade to AES-128/AES-192.
    if len(key_bytes) != _KEY_SIZE:
        raise ValueError(
            "FATAL: SIDECAR_ENCRYPTION_KEY must decode to 32 bytes, "
            f"got {len(key_bytes)} bytes (whitespace is not allowed)"
        )

    _ENCRYPTION_KEY = key_bytes
    _cipher = AESGCM(key_bytes)
    logger.info("crypto_initialized")


def encrypt(plaintext: str) -> str:
    """Encrypt plaintext using AES-256-GCM.

    A fresh random 12-byte nonce is generated for every call.

    Args:
        plaintext: Text to encrypt; it is UTF-8 encoded before encryption.

    Returns:
        ``"<nonce_hex>:<ciphertext_hex>"``. The ciphertext part includes the
        GCM authentication tag that AESGCM appends.

    Raises:
        RuntimeError: If init_crypto() has not completed successfully.
    """
    if _cipher is None:
        raise RuntimeError("Crypto not initialized. Call init_crypto() first.")
    nonce = secrets.token_bytes(_NONCE_SIZE)
    ciphertext = _cipher.encrypt(nonce, plaintext.encode("utf-8"), None)
    return nonce.hex() + ":" + ciphertext.hex()


def decrypt(encrypted: str) -> str:
    """Decrypt AES-256-GCM ciphertext produced by encrypt().

    Args:
        encrypted: ``"<nonce_hex>:<ciphertext_hex>"``.

    Returns:
        Decrypted plaintext string.

    Raises:
        RuntimeError: If init_crypto() has not completed successfully.
        ValueError: If the value is malformed, fails authentication (e.g.
            wrong key or tampered data), or is not valid UTF-8.
    """
    if _cipher is None:
        raise RuntimeError("Crypto not initialized. Call init_crypto() first.")

    parts = encrypted.split(":", 1)
    if len(parts) != 2:
        raise ValueError("Invalid encrypted format: expected nonce:ciphertext")

    try:
        nonce = bytes.fromhex(parts[0])
        ciphertext = bytes.fromhex(parts[1])
    except ValueError as err:
        raise ValueError(
            "Invalid encrypted format: nonce and ciphertext must be hex"
        ) from err

    try:
        plaintext = _cipher.decrypt(nonce, ciphertext, None).decode("utf-8")
    except Exception as err:
        # Translation boundary: every cipher/decoding failure is reported to
        # callers as ValueError. InvalidTag has an empty message, so fall back
        # to the exception class name to keep the error informative.
        reason = str(err) or type(err).__name__
        raise ValueError(f"Decryption failed: {reason}") from err
    return plaintext


def is_initialized() -> bool:
    """Check if crypto has been initialized."""
    return _cipher is not None


def mask_api_key(api_key_plain: str) -> str:
    """Mask API key for display: show first 6 + last 4 chars.

    Keys of 10 characters or fewer show only their first 2 characters
    followed by ``****``.
    """
    if len(api_key_plain) <= 10:
        return api_key_plain[:2] + "****"
    return api_key_plain[:6] + "****" + api_key_plain[-4:]


def try_decrypt_existing(encrypted_value: str) -> str | None:
    """Try to decrypt an existing encrypted value.

    This is a deliberate best-effort helper. Any failure is logged and
    swallowed, including a changed encryption key, a malformed stored value,
    or crypto not being initialized.

    Returns:
        The plaintext if successful, otherwise None.
    """
    try:
        return decrypt(encrypted_value)
    except Exception as exc:
        logger.warning(
            "decrypt_existing_failed",
            hint="Encryption key may have been changed, existing keys unrecoverable",
            # Record the real cause so non-key-change failures are diagnosable.
            error_type=type(exc).__name__,
            error=str(exc),
        )
        return None