feat(scheduler): acquire/release cf-orch VRAM lease per batch worker
Before running a batch of tasks, the scheduler now requests a VRAM lease from the cf-orch coordinator (POST /api/leases). The lease is held for the full batch and released in the finally block so it's always cleaned up even on error. Falls back gracefully if the coordinator is unreachable. Adds coordinator_url and service_name params to TaskScheduler.__init__ and get_scheduler() so callers can override the default localhost:7700.
This commit is contained in:
parent
67701f0d29
commit
6b8e421eb2
1 changed files with 75 additions and 0 deletions
|
|
@ -127,6 +127,9 @@ class TaskScheduler:
|
|||
vram_budgets: dict[str, float],
|
||||
available_vram_gb: Optional[float] = None,
|
||||
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
|
||||
coordinator_url: str = "http://localhost:7700",
|
||||
service_name: str = "peregrine",
|
||||
lease_priority: int = 2,
|
||||
) -> None:
|
||||
self._db_path = db_path
|
||||
self._run_task = run_task_fn
|
||||
|
|
@ -134,6 +137,10 @@ class TaskScheduler:
|
|||
self._budgets: dict[str, float] = dict(vram_budgets)
|
||||
self._max_queue_depth = max_queue_depth
|
||||
|
||||
self._coordinator_url = coordinator_url.rstrip("/")
|
||||
self._service_name = service_name
|
||||
self._lease_priority = lease_priority
|
||||
|
||||
self._lock = threading.Lock()
|
||||
self._wake = threading.Event()
|
||||
self._stop = threading.Event()
|
||||
|
|
@ -240,8 +247,70 @@ class TaskScheduler:
|
|||
self._reserved_vram += budget
|
||||
thread.start()
|
||||
|
||||
def _acquire_lease(self, task_type: str) -> Optional[str]:
    """Ask the cf-orch coordinator for a VRAM lease.

    Best-effort: returns the lease_id string on success, or None when
    httpx is unavailable, no VRAM budget is configured for *task_type*,
    no node/GPU can be discovered, or any HTTP step fails.
    """
    if httpx is None:
        # httpx is an optional dependency; without it, leasing is a no-op.
        return None

    budget_gb = self._budgets.get(task_type, 0.0)
    if budget_gb <= 0:
        # No budget configured for this task type -> nothing to lease.
        return None
    needed_mb = int(budget_gb * 1024)

    try:
        # Discover registered nodes, then pick the GPU with the most
        # free VRAM (first match wins on ties, in API order).
        resp = httpx.get(f"{self._coordinator_url}/api/nodes", timeout=2.0)
        if resp.status_code != 200:
            return None
        nodes = resp.json().get("nodes", [])
        if not nodes:
            return None

        candidates = (
            (node["node_id"], gpu["gpu_id"], gpu.get("vram_free_mb", 0))
            for node in nodes
            for gpu in node.get("gpus", [])
        )
        best = max(candidates, key=lambda c: c[2], default=None)
        if best is None:
            return None
        node_id, gpu_id, _free = best

        lease_resp = httpx.post(
            f"{self._coordinator_url}/api/leases",
            json={
                "node_id": node_id,
                "gpu_id": gpu_id,
                "mb": needed_mb,
                "service": self._service_name,
                "priority": self._lease_priority,
            },
            timeout=3.0,
        )
        if lease_resp.status_code == 200:
            lease_id = lease_resp.json()["lease"]["lease_id"]
            logger.debug(
                "Acquired VRAM lease %s for task_type=%s (%d MB)",
                lease_id, task_type, needed_mb,
            )
            return lease_id
    except Exception as exc:
        # The coordinator may be unreachable; the scheduler must keep
        # working, so failures here are logged and swallowed.
        logger.debug("Lease acquire failed (non-fatal): %s", exc)
    return None
|
||||
|
||||
def _release_lease(self, lease_id: str) -> None:
    """Return a VRAM lease to the coordinator.

    Best-effort cleanup: any failure (httpx missing, empty lease_id,
    network or HTTP error) is logged at debug level and ignored.
    """
    if httpx is None or not lease_id:
        return
    url = f"{self._coordinator_url}/api/leases/{lease_id}"
    try:
        httpx.delete(url, timeout=3.0)
    except Exception as exc:
        logger.debug("Lease release failed (non-fatal): %s", exc)
    else:
        # Only log success when the DELETE actually completed.
        logger.debug("Released VRAM lease %s", lease_id)
|
||||
|
||||
def _batch_worker(self, task_type: str) -> None:
|
||||
"""Serial consumer for one task type. Runs until the type's deque is empty."""
|
||||
lease_id: Optional[str] = self._acquire_lease(task_type)
|
||||
try:
|
||||
while True:
|
||||
with self._lock:
|
||||
|
|
@ -253,6 +322,8 @@ class TaskScheduler:
|
|||
self._db_path, task.id, task_type, task.job_id, task.params
|
||||
)
|
||||
finally:
|
||||
if lease_id:
|
||||
self._release_lease(lease_id)
|
||||
with self._lock:
|
||||
self._active.pop(task_type, None)
|
||||
self._reserved_vram -= self._budgets.get(task_type, 0.0)
|
||||
|
|
@ -298,6 +369,8 @@ def get_scheduler(
|
|||
task_types: Optional[frozenset[str]] = None,
|
||||
vram_budgets: Optional[dict[str, float]] = None,
|
||||
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
|
||||
coordinator_url: str = "http://localhost:7700",
|
||||
service_name: str = "peregrine",
|
||||
) -> TaskScheduler:
|
||||
"""Return the process-level TaskScheduler singleton.
|
||||
|
||||
|
|
@ -324,6 +397,8 @@ def get_scheduler(
|
|||
task_types=task_types,
|
||||
vram_budgets=vram_budgets,
|
||||
max_queue_depth=max_queue_depth,
|
||||
coordinator_url=coordinator_url,
|
||||
service_name=service_name,
|
||||
)
|
||||
candidate.start()
|
||||
with _scheduler_lock:
|
||||
|
|
|
|||
Loading…
Reference in a new issue