diff --git a/circuitforge_core/tasks/scheduler.py b/circuitforge_core/tasks/scheduler.py
index a6c4453..754c542 100644
--- a/circuitforge_core/tasks/scheduler.py
+++ b/circuitforge_core/tasks/scheduler.py
@@ -127,6 +127,9 @@ class TaskScheduler:
         vram_budgets: dict[str, float],
         available_vram_gb: Optional[float] = None,
         max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
+        coordinator_url: str = "http://localhost:7700",
+        service_name: str = "peregrine",
+        lease_priority: int = 2,
     ) -> None:
         self._db_path = db_path
         self._run_task = run_task_fn
@@ -134,6 +137,10 @@ class TaskScheduler:
         self._budgets: dict[str, float] = dict(vram_budgets)
         self._max_queue_depth = max_queue_depth
 
+        self._coordinator_url = coordinator_url.rstrip("/")
+        self._service_name = service_name
+        self._lease_priority = lease_priority
+
         self._lock = threading.Lock()
         self._wake = threading.Event()
         self._stop = threading.Event()
@@ -240,8 +247,70 @@ class TaskScheduler:
                 self._reserved_vram += budget
             thread.start()
 
+    def _acquire_lease(self, task_type: str) -> Optional[str]:
+        """Request a VRAM lease from the coordinator. Returns lease_id or None."""
+        if httpx is None:
+            return None
+        budget_gb = self._budgets.get(task_type, 0.0)
+        if budget_gb <= 0:
+            return None
+        mb = int(budget_gb * 1024)
+        try:
+            # Pick the GPU with the most free VRAM on the first registered node
+            resp = httpx.get(f"{self._coordinator_url}/api/nodes", timeout=2.0)
+            if resp.status_code != 200:
+                return None
+            nodes = resp.json().get("nodes", [])
+            if not nodes:
+                return None
+            best_node = best_gpu = best_free = None
+            for node in nodes:
+                for gpu in node.get("gpus", []):
+                    free = gpu.get("vram_free_mb", 0)
+                    if best_free is None or free > best_free:
+                        best_node = node["node_id"]
+                        best_gpu = gpu["gpu_id"]
+                        best_free = free
+            if best_node is None:
+                return None
+            lease_resp = httpx.post(
+                f"{self._coordinator_url}/api/leases",
+                json={
+                    "node_id": best_node,
+                    "gpu_id": best_gpu,
+                    "mb": mb,
+                    "service": self._service_name,
+                    "priority": self._lease_priority,
+                },
+                timeout=3.0,
+            )
+            if lease_resp.status_code == 200:
+                lease_id = lease_resp.json()["lease"]["lease_id"]
+                logger.debug(
+                    "Acquired VRAM lease %s for task_type=%s (%d MB)",
+                    lease_id, task_type, mb,
+                )
+                return lease_id
+        except Exception as exc:
+            logger.debug("Lease acquire failed (non-fatal): %s", exc)
+        return None
+
+    def _release_lease(self, lease_id: str) -> None:
+        """Release a coordinator VRAM lease. Best-effort; failures are logged only."""
+        if httpx is None or not lease_id:
+            return
+        try:
+            httpx.delete(
+                f"{self._coordinator_url}/api/leases/{lease_id}",
+                timeout=3.0,
+            )
+            logger.debug("Released VRAM lease %s", lease_id)
+        except Exception as exc:
+            logger.debug("Lease release failed (non-fatal): %s", exc)
+
     def _batch_worker(self, task_type: str) -> None:
         """Serial consumer for one task type. Runs until the type's deque is empty."""
+        lease_id: Optional[str] = self._acquire_lease(task_type)
         try:
             while True:
                 with self._lock:
@@ -253,6 +322,8 @@ class TaskScheduler:
                     self._db_path, task.id, task_type, task.job_id, task.params
                 )
         finally:
+            if lease_id:
+                self._release_lease(lease_id)
             with self._lock:
                 self._active.pop(task_type, None)
                 self._reserved_vram -= self._budgets.get(task_type, 0.0)
@@ -298,6 +369,8 @@ def get_scheduler(
     task_types: Optional[frozenset[str]] = None,
     vram_budgets: Optional[dict[str, float]] = None,
     max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
+    coordinator_url: str = "http://localhost:7700",
+    service_name: str = "peregrine",
 ) -> TaskScheduler:
     """Return the process-level TaskScheduler singleton.
 
@@ -324,6 +397,8 @@ def get_scheduler(
         task_types=task_types,
         vram_budgets=vram_budgets,
         max_queue_depth=max_queue_depth,
+        coordinator_url=coordinator_url,
+        service_name=service_name,
     )
     candidate.start()
     with _scheduler_lock: