# Source code for nndeploy.server.task_queue

# queue.py

import copy
import heapq
import time
import threading
import queue as _queue
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Dict, List, Optional

# Upper bound on the number of finished-task entries kept in TaskQueue._hist;
# once the bound is reached the oldest entry is evicted before a new one is stored.
_MAX_HISTORY = 1000

[文档]class ExecutionStatus:
[文档] def __init__(self, ok: bool, msg: str = "", label: str | None = None): self.str = label or ("success" if ok else "failed") self.completed = ok self.messages = [msg] if msg else []
class TaskState(Enum):
    """Lifecycle states of a task tracked by the task queue."""

    # Still waiting in the priority queue.
    PENDING = auto()
    # Handed over to the multiprocessing job queue, but not running yet.
    DISPATCHED = auto()
    # Currently being executed.
    RUNNING = auto()
    # Terminal states.
    SUCCEEDED = auto()
    FAILED = auto()
    CANCELLED = auto()
@dataclass
class TaskRecord:
    """Bookkeeping entry for one task that has been taken out of the queue.

    All ``ts_*`` fields are ``time.time()`` values; the optional ones stay
    ``None`` until the corresponding lifecycle step has been recorded.
    """

    # Identifier of the record (key under which the queue tracks it).
    idx: int
    # The task description itself.
    payload: Dict[str, Any]
    # Current lifecycle state.
    state: TaskState = TaskState.PENDING
    # Submission time; defaults to "now" when the record is created.
    ts_submit: float = field(default_factory=time.time)
    # Time the task was marked as dispatched.
    ts_dispatch: Optional[float] = None
    # Time the task was marked as started.
    ts_start: Optional[float] = None
    # Time the task finished.
    ts_finish: Optional[float] = None
    # PID of the worker process executing the task, when known.
    worker_pid: Optional[int] = None
class TaskQueue:
    """Thread-safe priority queue that also tracks in-flight and finished tasks.

    Three pieces of state are kept, all guarded by one re-entrant lock:

    * ``_pq``     -- heap of ``(prio, ts_submit, seq, payload)`` entries that
      have been submitted but not yet taken by :meth:`get`. Lower ``prio``
      values are served first; ``seq`` is a submission counter that keeps
      equal-priority entries in FIFO order and guarantees that payloads are
      never compared with each other.
    * ``_active`` -- ``idx -> TaskRecord`` for tasks handed out by :meth:`get`
      and not yet finished.
    * ``_hist``   -- ``task_id -> summary dict`` for finished or cancelled
      tasks, bounded by ``_MAX_HISTORY`` (oldest entries are evicted first).
    """

    def __init__(self, server: "NnDeployServer", job_mp_q: "mp.Queue"):
        """Create an empty queue.

        Args:
            server: Object whose ``notify_task_done`` is called by
                :meth:`task_done`.
            job_mp_q: Queue of ``(idx, payload)`` items that
                :meth:`drain_job_q` empties via ``get_nowait()``.
        """
        self.server = server
        self._mtx = threading.RLock()
        self._not_empty = threading.Condition(self._mtx)
        self._counter = 0  # next idx handed out by get()
        self._put_seq = 0  # tie-breaker for heap entries, see put()
        self._pq: List[Any] = []
        self._active: Dict[int, "TaskRecord"] = {}
        self._hist: Dict[str, Any] = {}
        self._job_q = job_mp_q

    def put(self, payload, prio: int = 0):
        """Submit ``payload`` with priority ``prio`` (lower is served first)."""
        with self._mtx:
            # The sequence number makes every entry unique before the payload
            # is reached, so two dict payloads with equal priority and equal
            # timestamp never get compared (which would raise TypeError).
            entry = (prio, time.time(), self._put_seq, payload)
            self._put_seq += 1
            heapq.heappush(self._pq, entry)
            self._not_empty.notify()

    def get(self, timeout: Optional[float] = None):
        """Take the next task out of the queue and start tracking it.

        Args:
            timeout: Maximum number of seconds to wait for a task in total;
                ``None`` waits indefinitely.

        Returns:
            ``(idx, payload)`` for the dequeued task, or ``None`` when no task
            became available within ``timeout``. The returned payload is the
            submitted object; the tracking record keeps a deep copy of it.
        """
        with self._not_empty:
            # wait_for() re-checks the predicate after every wake-up while
            # honouring one overall deadline, so a wake-up that loses the race
            # to another consumer does not restart the full timeout.
            if not self._not_empty.wait_for(lambda: bool(self._pq), timeout):
                return None
            _prio, ts_submit, _seq, payload = heapq.heappop(self._pq)
            idx = self._counter
            self._counter += 1
            self._active[idx] = TaskRecord(
                idx=idx,
                payload=copy.deepcopy(payload),
                state=TaskState.PENDING,
                ts_submit=ts_submit,
            )
            return idx, payload

    def mark_dispatched(self, idx: int):
        """Mark the active task ``idx`` as dispatched; unknown ids are ignored."""
        with self._mtx:
            rec = self._active.get(idx)
            if rec is None:
                return
            rec.state = TaskState.DISPATCHED
            rec.ts_dispatch = time.time()

    def mark_started(self, task_id: str, worker_pid: Optional[int] = None):
        """Mark the active task whose payload ``"id"`` is ``task_id`` as running.

        Args:
            task_id: Value of the ``"id"`` key in the task payload.
            worker_pid: PID of the executing worker; recorded when not ``None``.
        """
        with self._mtx:
            rec = self._find_active_unlocked(task_id)
            if rec is None:
                return
            rec.state = TaskState.RUNNING
            rec.ts_start = time.time()
            if worker_pid is not None:
                rec.worker_pid = worker_pid

    def task_done(self, idx: int, status: "ExecutionStatus", results: Dict,
                  time_profile_map: Dict):
        """Move task ``idx`` from the active set to history and notify the server.

        Unknown ``idx`` values are ignored (no history entry, no notification).

        Args:
            idx: Identifier returned by :meth:`get`.
            status: Execution outcome; ``completed`` selects SUCCEEDED, a
                ``str`` of ``"cancelled"`` selects CANCELLED, anything else
                FAILED.
            results: Passed through to ``server.notify_task_done``.
            time_profile_map: Stored in history and passed to the server.
        """
        with self._mtx:
            rec = self._active.pop(idx, None)
            if rec is None:
                return
            rec.ts_finish = time.time()
            if status.completed:
                final_state = TaskState.SUCCEEDED
            elif status.str == "cancelled":
                final_state = TaskState.CANCELLED
            else:
                final_state = TaskState.FAILED
            task_id = rec.payload.get("id")
            self._store_history_unlocked(task_id, {
                "task": rec.payload,
                "status": status.__dict__,
                "state": final_state.name,
                "ts_submit": rec.ts_submit,
                "ts_dispatch": rec.ts_dispatch,
                "ts_start": rec.ts_start,
                "ts_finish": rec.ts_finish,
                "worker_pid": rec.worker_pid,
                "time_profile": time_profile_map,
            })
        # Notify after releasing the lock so that a slow server callback does
        # not block producers and consumers of the queue.
        self.server.notify_task_done(task_id, status, results, time_profile_map)

    def get_current_queue(self):
        """Return a snapshot of running, dispatched and still-queued tasks.

        Returns:
            Dict with keys ``"RUNNING"`` and ``"DISPATCHED"`` (lists of record
            summaries) and ``"PENDING"`` (list of ``(prio, ts_submit, payload)``
            tuples in the order they would be dequeued). All payloads are deep
            copies.

        Note:
            Records already taken by :meth:`get` but not yet marked as
            dispatched are in neither list, as before.
        """
        with self._mtx:
            running = []
            dispatched = []
            for rec in self._active.values():
                if rec.state == TaskState.RUNNING:
                    bucket = running
                elif rec.state == TaskState.DISPATCHED:
                    bucket = dispatched
                else:
                    continue
                bucket.append({
                    "idx": rec.idx,
                    "task": copy.deepcopy(rec.payload),
                    "state": rec.state.name,
                    "ts_submit": rec.ts_submit,
                    "ts_dispatch": rec.ts_dispatch,
                    "ts_start": rec.ts_start,
                    "worker_pid": rec.worker_pid,
                })
            # sorted() is safe: the sequence number decides every comparison
            # before a payload could be reached.
            pq_snapshot = [
                (prio, ts, copy.deepcopy(payload))
                for (prio, ts, _seq, payload) in sorted(self._pq)
            ]
            return {
                "RUNNING": running,
                "DISPATCHED": dispatched,
                "PENDING": pq_snapshot,
            }

    def get_history(self, max_items: Optional[int] = None):
        """Return finished-task history, oldest first.

        Args:
            max_items: Keep only the most recent ``max_items`` entries.
                ``None``, ``0`` or a negative value means "no limit".

        Returns:
            A new dict ``task_id -> summary``; the summary dicts themselves
            are shared with the internal history (shallow copy).
        """
        with self._mtx:
            items = list(self._hist.items())
            if max_items is not None and max_items > 0:
                items = items[-max_items:]
            return dict(items)

    def get_task_by_id(self, task_id: str) -> Optional[dict]:
        """Look up a task by payload ``"id"`` among active tasks, then history.

        Returns:
            A summary dict (deep-copied), or ``None`` when the id is unknown.
        """
        with self._mtx:
            rec = self._find_active_unlocked(task_id)
            if rec is not None:
                return {
                    "task": copy.deepcopy(rec.payload),
                    "state": rec.state.name,
                    "ts_submit": rec.ts_submit,
                    "ts_dispatch": rec.ts_dispatch,
                    "ts_start": rec.ts_start,
                    "worker_pid": rec.worker_pid,
                }
            record = self._hist.get(task_id)
            return copy.deepcopy(record) if record else None

    def _find_active_unlocked(self, task_id):
        """Return the active record whose payload ``"id"`` is ``task_id``, else None.

        The caller must hold ``self._mtx``.
        """
        for rec in self._active.values():
            if rec.payload.get("id") == task_id:
                return rec
        return None

    def _store_history_unlocked(self, task_id, entry: dict):
        """Insert ``entry`` as the newest history item, evicting the oldest if full.

        The caller must hold ``self._mtx``.
        """
        # Re-inserting an existing id must not evict an unrelated entry, and
        # the refreshed entry should count as the most recent one.
        self._hist.pop(task_id, None)
        while self._hist and len(self._hist) >= _MAX_HISTORY:
            self._hist.pop(next(iter(self._hist)))
        self._hist[task_id] = entry

    def _push_hist_cancelled_unlocked(self, idx: int, payload: dict, reason: str):
        """Record task ``idx`` as cancelled in history and drop it from the active set.

        The caller must hold ``self._mtx``. Timestamps and worker pid are taken
        from the active record when one exists and are ``None`` otherwise.
        """
        rec = self._active.pop(idx, None)
        status = ExecutionStatus(ok=False, msg=reason, label="cancelled")
        self._store_history_unlocked(payload.get("id"), {
            "task": payload,
            "status": status.__dict__,
            "state": TaskState.CANCELLED.name,
            "ts_submit": rec.ts_submit if rec else None,
            "ts_dispatch": rec.ts_dispatch if rec else None,
            "ts_start": rec.ts_start if rec else None,
            "ts_finish": time.time(),
            "worker_pid": rec.worker_pid if rec else None,
            # Same key set as entries written by task_done().
            "time_profile": None,
        })

    def clear_pending(self) -> int:
        """Discard every task still waiting in the heap; return how many were dropped."""
        with self._mtx:
            dropped = len(self._pq)
            self._pq.clear()
            return dropped

    def drain_job_q(self) -> int:
        """Empty the job queue, recording every removed task as cancelled.

        Returns:
            Number of ``(idx, payload)`` items removed from the job queue.
        """
        drained = 0
        while True:
            try:
                idx, payload = self._job_q.get_nowait()
            except _queue.Empty:
                break
            except Exception:
                # Deliberately best-effort: any other failure while reading or
                # unpacking an item ends the drain instead of propagating.
                break
            with self._mtx:
                self._push_hist_cancelled_unlocked(
                    idx, payload, reason="flushed from job_q"
                )
            drained += 1
        return drained

    def flush(self, retries: int = 2, settle_delay: float = 0.02) -> dict:
        """Drop all queued work: clear the heap and drain the job queue.

        After the first drain the job queue is drained ``retries`` more times,
        sleeping ``settle_delay`` seconds before each attempt (presumably to
        catch items that were still in flight -- confirm against the producer).

        Args:
            retries: Number of additional drain passes (default 2).
            settle_delay: Seconds to sleep before each additional pass
                (default 0.02); non-positive values skip the sleep.

        Returns:
            ``{"cleared_pending": <int>, "drained_job_q": <int>}``.
        """
        cleared_pending = self.clear_pending()
        drained_jobq = self.drain_job_q()
        for _ in range(retries):
            if settle_delay > 0:
                time.sleep(settle_delay)
            drained_jobq += self.drain_job_q()
        return {"cleared_pending": cleared_pending, "drained_job_q": drained_jobq}