Files
EnterpriseArchitect/services/nvidia_sidecar/priority_queue.py
T

253 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
NVIDIA Sidecar 限流代理 — 四级优先级请求队列模块 (§3.3)
管理待处理的 NVIDIA API 请求,按优先级 + FIFO 出队。
支持三种队列满策略:PASSTHROUGH / REJECT / DROP_LOWEST。
"""
from __future__ import annotations
import asyncio
import heapq
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from nvidia_sidecar.rate_limiter import Priority
# ---------------------------------------------------------------------------
# 队列满策略
# ---------------------------------------------------------------------------
class QueueFullPolicy(str, Enum):
"""队列满时的处理策略。"""
PASSTHROUGH = "passthrough" # 直通上游,绕过排队(fail-open 子策略)
REJECT = "reject" # 返回 503 Service Unavailable
DROP_LOWEST = "drop_lowest" # 丢弃队列中最低优先级元素,插入新请求
# ---------------------------------------------------------------------------
# 队列元素
# ---------------------------------------------------------------------------
@dataclass(order=True)
class PriorityQueueItem:
"""优先级队列元素。
``sort_index`` 由 ``(priority, timestamp)`` 组成,
Python 的 ``__lt__`` 按字段顺序比较:先比 priority,再比 timestamp。
数值越小越优先(URGENT=1 优于 HIGH=2)。
"""
sort_index: tuple[int, float] = field(compare=True)
priority: Priority = field(compare=False)
request_id: str = field(compare=False)
payload: dict[str, Any] = field(compare=False)
enqueued_at: float = field(compare=False)
headers: dict[str, str] = field(default_factory=dict, compare=False)
# ---------------------------------------------------------------------------
# 优先级请求队列
# ---------------------------------------------------------------------------
class QueueFullError(Exception):
"""队列已满且策略为 REJECT 时抛出。"""
pass
class QueueFullPassthrough(Exception):
"""队列已满且策略为 PASSTHROUGH 时抛出,由调用方绕过队列直通上游。"""
pass
class PriorityRequestQueue:
"""异步线程安全的四级优先级请求队列。
内部使用 ``asyncio.Lock`` 保护并发操作,
基于 ``heapq`` + ``asyncio.Event`` 实现阻塞出队。
"""
def __init__(self, max_size: int = 500) -> None:
"""初始化优先级队列。
Args:
max_size: 队列最大容量。
Raises:
ValueError: max_size <= 0。
"""
if max_size <= 0:
raise ValueError(f"max_size 必须为正整数,当前值: {max_size}")
self.max_size: int = max_size
self._heap: list[PriorityQueueItem] = []
self._lock: asyncio.Lock = asyncio.Lock()
self._not_empty: asyncio.Event = asyncio.Event()
self._full_policy: QueueFullPolicy = QueueFullPolicy.PASSTHROUGH
# 统计
self._total_enqueued: int = 0
self._total_dequeued: int = 0
self._total_dropped: int = 0
# ---- 队列满策略 ----
def set_full_policy(self, policy: QueueFullPolicy) -> None:
"""设置队列满时的处理策略。
Args:
policy: QueueFullPolicy 枚举值。
"""
self._full_policy = policy
@property
def full_policy(self) -> QueueFullPolicy:
"""当前队列满策略。"""
return self._full_policy
# ---- 动态容量调整 ----
def set_max_size(self, new_size: int) -> tuple[bool, str]:
"""动态调整队列最大容量(热重载)。
缩小操作受保护:如果 new_size 小于当前排队数,拒绝变更并
提示当前队列深度。
Args:
new_size: 新的最大容量。
Returns:
(成功标志, 消息)。成功时标志为 True,消息含新旧容量对比;
失败时标志为 False,消息含拒绝原因和当前深度。
Raises:
ValueError: new_size <= 0。
"""
if new_size <= 0:
raise ValueError(f"max_size 必须为正整数,当前值: {new_size}")
current = len(self._heap)
if new_size < current:
return (False, f"拒绝缩小:新上限 {new_size} < 当前排队数 {current},需要先排空或提升上限")
old = self.max_size
self.max_size = new_size
return (True, f"队列上限已调整:{old}{new_size}{'(当前排队 ' + str(current) + '' if current > 0 else ''}")
# ---- 入队 ----
async def put(
self,
item: dict[str, Any],
priority: Priority = Priority.NORMAL,
headers: dict[str, str] | None = None,
) -> str:
"""将请求放入队列。
Args:
item: 请求体(JSON 序列化的 dict)。
priority: 请求优先级,默认 NORMAL。
headers: 原始请求 headers。
Returns:
分配的唯一 request_id。
Raises:
QueueFullError: 队列满且策略为 REJECT。
"""
request_id = str(uuid.uuid4())
headers = headers or {}
queue_item = PriorityQueueItem(
sort_index=(int(priority), time.monotonic()),
priority=priority,
request_id=request_id,
payload=item,
enqueued_at=time.monotonic(),
headers=headers,
)
async with self._lock:
queue_size = len(self._heap)
if queue_size >= self.max_size:
if self._full_policy == QueueFullPolicy.REJECT:
raise QueueFullError(
f"队列已满 ({queue_size}/{self.max_size}),策略: reject"
)
elif self._full_policy == QueueFullPolicy.DROP_LOWEST:
# 丢弃 heap 中优先级最低(值最大)的元素
# heap 是最小堆,找最大值需要遍历
max_val_item = max(self._heap, key=lambda x: x.sort_index)
self._heap.remove(max_val_item)
heapq.heapify(self._heap)
self._total_dropped += 1
# PASSTHROUGH 策略:不插入队列,抛异常让调用方绕过排队
else:
raise QueueFullPassthrough(
f"队列已满 ({queue_size}/{self.max_size}),策略: passthrough"
)
heapq.heappush(self._heap, queue_item)
self._total_enqueued += 1
self._not_empty.set()
return request_id
# ---- 出队 ----
async def get(self, timeout: float = 1.0) -> PriorityQueueItem | None:
"""从队列取出下一个元素(阻塞、优先级排序)。
Args:
timeout: 阻塞等待的最大秒数,默认 1.0。
Returns:
优先级最高的队列元素;超时无元素时返回 None。
"""
deadline = time.monotonic() + timeout
while True:
async with self._lock:
if self._heap:
item = heapq.heappop(self._heap)
self._total_dequeued += 1
if not self._heap:
self._not_empty.clear()
return item
# 队列为空,等待新元素入队
remaining = deadline - time.monotonic()
if remaining <= 0:
return None
try:
await asyncio.wait_for(
self._not_empty.wait(),
timeout=remaining,
)
except asyncio.TimeoutError:
return None
# ---- 状态查询 ----
async def get_queue_size(self) -> int:
"""返回当前队列长度。"""
async with self._lock:
return len(self._heap)
async def get_stats(self) -> dict[str, Any]:
"""返回队列统计信息。"""
async with self._lock:
depth_by_priority: dict[str, int] = {}
for item in self._heap:
key = item.priority.name
depth_by_priority[key] = depth_by_priority.get(key, 0) + 1
return {
"max_size": self.max_size,
"current_size": len(self._heap),
"total_enqueued": self._total_enqueued,
"total_dequeued": self._total_dequeued,
"total_dropped": self._total_dropped,
"depth_by_priority": depth_by_priority,
"full_policy": self._full_policy.value,
"utilization": len(self._heap) / self.max_size if self.max_size > 0 else 0.0,
}