feat(scheduler): acquire/release cf-orch VRAM lease per batch worker

Before running a batch of tasks, the scheduler now requests a VRAM lease
from the cf-orch coordinator (POST /api/leases). The lease is held for the
full batch and released in the finally block so it's always cleaned up even
on error. Falls back gracefully if the coordinator is unreachable.

Adds coordinator_url and service_name params to TaskScheduler.__init__
and get_scheduler() so callers can override the default localhost:7700.
This commit is contained in:
pyr0ball 2026-04-01 11:06:16 -07:00
parent 67701f0d29
commit 6b8e421eb2

View file

@ -127,6 +127,9 @@ class TaskScheduler:
vram_budgets: dict[str, float],
available_vram_gb: Optional[float] = None,
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
coordinator_url: str = "http://localhost:7700",
service_name: str = "peregrine",
lease_priority: int = 2,
) -> None:
self._db_path = db_path
self._run_task = run_task_fn
@ -134,6 +137,10 @@ class TaskScheduler:
self._budgets: dict[str, float] = dict(vram_budgets)
self._max_queue_depth = max_queue_depth
self._coordinator_url = coordinator_url.rstrip("/")
self._service_name = service_name
self._lease_priority = lease_priority
self._lock = threading.Lock()
self._wake = threading.Event()
self._stop = threading.Event()
@ -240,8 +247,70 @@ class TaskScheduler:
self._reserved_vram += budget
thread.start()
def _acquire_lease(self, task_type: str) -> Optional[str]:
    """Best-effort: obtain a VRAM lease from the cf-orch coordinator.

    Returns the coordinator-issued lease_id, or None when httpx is not
    installed, the task type has no positive VRAM budget, the coordinator
    is unreachable, or no GPU is advertised. Never raises.
    """
    if httpx is None:
        return None
    budget_gb = self._budgets.get(task_type, 0.0)
    if budget_gb <= 0:
        return None
    needed_mb = int(budget_gb * 1024)
    try:
        # Query the coordinator's node inventory and target the GPU
        # currently reporting the most free VRAM.
        nodes_resp = httpx.get(f"{self._coordinator_url}/api/nodes", timeout=2.0)
        if nodes_resp.status_code != 200:
            return None
        # best holds (free_mb, node_id, gpu_id) for the roomiest GPU seen so far;
        # strict ">" means the first GPU wins ties, matching scan order.
        best = None
        for node in nodes_resp.json().get("nodes", []):
            for gpu in node.get("gpus", []):
                free_mb = gpu.get("vram_free_mb", 0)
                if best is None or free_mb > best[0]:
                    best = (free_mb, node["node_id"], gpu["gpu_id"])
        if best is None:
            return None
        payload = {
            "node_id": best[1],
            "gpu_id": best[2],
            "mb": needed_mb,
            "service": self._service_name,
            "priority": self._lease_priority,
        }
        lease_resp = httpx.post(
            f"{self._coordinator_url}/api/leases",
            json=payload,
            timeout=3.0,
        )
        if lease_resp.status_code == 200:
            lease_id = lease_resp.json()["lease"]["lease_id"]
            logger.debug(
                "Acquired VRAM lease %s for task_type=%s (%d MB)",
                lease_id, task_type, needed_mb,
            )
            return lease_id
    except Exception as exc:
        # Coordinator down/unreachable is expected in standalone deployments;
        # scheduling proceeds without a lease.
        logger.debug("Lease acquire failed (non-fatal): %s", exc)
    return None
def _release_lease(self, lease_id: str) -> None:
    """Best-effort release of a coordinator VRAM lease.

    No-op when httpx is unavailable or lease_id is empty; any transport
    failure is logged at debug level and swallowed so batch teardown in
    the caller's finally block can never be interrupted.
    """
    if not lease_id or httpx is None:
        return
    try:
        release_url = f"{self._coordinator_url}/api/leases/{lease_id}"
        httpx.delete(release_url, timeout=3.0)
        logger.debug("Released VRAM lease %s", lease_id)
    except Exception as exc:
        logger.debug("Lease release failed (non-fatal): %s", exc)
def _batch_worker(self, task_type: str) -> None:
"""Serial consumer for one task type. Runs until the type's deque is empty."""
lease_id: Optional[str] = self._acquire_lease(task_type)
try:
while True:
with self._lock:
@ -253,6 +322,8 @@ class TaskScheduler:
self._db_path, task.id, task_type, task.job_id, task.params
)
finally:
if lease_id:
self._release_lease(lease_id)
with self._lock:
self._active.pop(task_type, None)
self._reserved_vram -= self._budgets.get(task_type, 0.0)
@ -298,6 +369,8 @@ def get_scheduler(
task_types: Optional[frozenset[str]] = None,
vram_budgets: Optional[dict[str, float]] = None,
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
coordinator_url: str = "http://localhost:7700",
service_name: str = "peregrine",
) -> TaskScheduler:
"""Return the process-level TaskScheduler singleton.
@ -324,6 +397,8 @@ def get_scheduler(
task_types=task_types,
vram_budgets=vram_budgets,
max_queue_depth=max_queue_depth,
coordinator_url=coordinator_url,
service_name=service_name,
)
candidate.start()
with _scheduler_lock: