diff --git a/circuitforge_core/resources/cli.py b/circuitforge_core/resources/cli.py index 1e907a4..9e65f08 100644 --- a/circuitforge_core/resources/cli.py +++ b/circuitforge_core/resources/cli.py @@ -32,9 +32,14 @@ def start( profile: Annotated[Optional[Path], typer.Option(help="Profile YAML path")] = None, host: str = "0.0.0.0", port: int = 7700, + node_id: str = "local", agent_port: int = 7701, ) -> None: - """Start the cf-orch coordinator (auto-detects GPU profile if not specified).""" + """Start the cf-orch coordinator (auto-detects GPU profile if not specified). + + Automatically pre-registers the local agent so its GPUs appear on the + dashboard immediately. Remote nodes self-register via POST /api/nodes. + """ from circuitforge_core.resources.coordinator.lease_manager import LeaseManager from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry from circuitforge_core.resources.coordinator.agent_supervisor import AgentSupervisor @@ -52,8 +57,6 @@ def start( "Warning: no GPUs detected via nvidia-smi — coordinator running with 0 VRAM" ) else: - for gpu in gpus: - lease_manager.register_gpu("local", gpu.gpu_id, gpu.vram_total_mb) typer.echo(f"Detected {len(gpus)} GPU(s)") if profile: @@ -67,6 +70,11 @@ def start( ) typer.echo(f"Auto-selected profile: {active_profile.name}") + # Pre-register the local agent — the heartbeat loop will poll it for live GPU data. + local_agent_url = f"http://127.0.0.1:{agent_port}" + supervisor.register(node_id, local_agent_url) + typer.echo(f"Registered local node '{node_id}' → {local_agent_url}") + coordinator_app = create_coordinator_app( lease_manager=lease_manager, profile_registry=profile_registry, @@ -83,10 +91,47 @@ def agent( node_id: str = "local", host: str = "0.0.0.0", port: int = 7701, + advertise_host: Optional[str] = None, ) -> None: - """Start a cf-orch node agent (for remote nodes like Navi, Huginn).""" + """Start a cf-orch node agent and self-register with the coordinator. 
+ + A background thread POSTs the agent URL to the coordinator about two seconds + after launch so it appears on the dashboard without manual configuration; failures only print a warning. + + Use --advertise-host to override the host the coordinator should use to reach + this agent (e.g. on a multi-homed or NATted host); a 0.0.0.0 or :: bind is otherwise advertised as 127.0.0.1. + """ + import asyncio + import threading + import httpx from circuitforge_core.resources.agent.app import create_agent_app + # The URL the coordinator should use to reach this agent. + reach_host = advertise_host or ("127.0.0.1" if host in ("0.0.0.0", "::") else host) + agent_url = f"http://{reach_host}:{port}" + + def _register_in_background() -> None: + """POST registration to coordinator after a fixed 2 s delay (intended to give uvicorn time to bind).""" + import time + time.sleep(2.0) + try: + resp = httpx.post( + f"{coordinator}/api/nodes", + json={"node_id": node_id, "agent_url": agent_url}, + timeout=5.0, + ) + if resp.is_success: + typer.echo(f"Registered with coordinator at {coordinator} as '{node_id}'") + else: + typer.echo( + f"Warning: coordinator registration returned {resp.status_code}", err=True + ) + except Exception as exc: + typer.echo(f"Warning: could not reach coordinator at {coordinator}: {exc}", err=True) + + # Fire registration in a daemon thread so uvicorn.run() can start blocking immediately. 
+ threading.Thread(target=_register_in_background, daemon=True).start() + agent_app = create_agent_app(node_id=node_id) typer.echo(f"Starting cf-orch agent [{node_id}] on {host}:{port}") uvicorn.run(agent_app, host=host, port=port) diff --git a/circuitforge_core/resources/coordinator/app.py b/circuitforge_core/resources/coordinator/app.py index 6c9961b..abf333f 100644 --- a/circuitforge_core/resources/coordinator/app.py +++ b/circuitforge_core/resources/coordinator/app.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import asynccontextmanager from pathlib import Path from typing import Any @@ -7,13 +8,13 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse from pydantic import BaseModel -_DASHBOARD_HTML = (Path(__file__).parent / "dashboard.html").read_text() - from circuitforge_core.resources.coordinator.agent_supervisor import AgentSupervisor from circuitforge_core.resources.coordinator.eviction_engine import EvictionEngine from circuitforge_core.resources.coordinator.lease_manager import LeaseManager from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry +_DASHBOARD_HTML = (Path(__file__).parent / "dashboard.html").read_text() + class LeaseRequest(BaseModel): node_id: str @@ -24,6 +25,11 @@ class LeaseRequest(BaseModel): ttl_s: float = 0.0 +class NodeRegisterRequest(BaseModel): + node_id: str + agent_url: str # e.g. 
"http://10.1.10.71:7701" + + def create_coordinator_app( lease_manager: LeaseManager, profile_registry: ProfileRegistry, @@ -31,7 +37,15 @@ def create_coordinator_app( ) -> FastAPI: eviction_engine = EvictionEngine(lease_manager=lease_manager) - app = FastAPI(title="cf-orch-coordinator") + @asynccontextmanager + async def _lifespan(app: FastAPI): # type: ignore[type-arg] + import asyncio + task = asyncio.create_task(agent_supervisor.run_heartbeat_loop()) + yield + agent_supervisor.stop() + task.cancel() + + app = FastAPI(title="cf-orch-coordinator", lifespan=_lifespan) @app.get("/", response_class=HTMLResponse, include_in_schema=False) def dashboard() -> HTMLResponse: @@ -65,6 +79,13 @@ def create_coordinator_app( ] } + @app.post("/api/nodes") + async def register_node(req: NodeRegisterRequest) -> dict[str, Any]: + """Agents call this to self-register. Coordinator immediately polls for GPU info.""" + agent_supervisor.register(req.node_id, req.agent_url) + await agent_supervisor.poll_agent(req.node_id) + return {"registered": True, "node_id": req.node_id} + @app.get("/api/profiles") def get_profiles() -> dict[str, Any]: return { diff --git a/circuitforge_core/tasks/scheduler.py b/circuitforge_core/tasks/scheduler.py index a6c4453..4cb3347 100644 --- a/circuitforge_core/tasks/scheduler.py +++ b/circuitforge_core/tasks/scheduler.py @@ -127,6 +127,9 @@ class TaskScheduler: vram_budgets: dict[str, float], available_vram_gb: Optional[float] = None, max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH, + coordinator_url: str = "http://localhost:7700", + service_name: str = "peregrine", + lease_priority: int = 2, ) -> None: self._db_path = db_path self._run_task = run_task_fn @@ -134,6 +137,10 @@ class TaskScheduler: self._budgets: dict[str, float] = dict(vram_budgets) self._max_queue_depth = max_queue_depth + self._coordinator_url = coordinator_url.rstrip("/") + self._service_name = service_name + self._lease_priority = lease_priority + self._lock = threading.Lock() self._wake 
= threading.Event() self._stop = threading.Event() @@ -196,11 +203,21 @@ class TaskScheduler: self._wake.set() def shutdown(self, timeout: float = 5.0) -> None: - """Signal the scheduler to stop and wait for it to exit.""" + """Signal the scheduler to stop and wait for it to exit. + + Joins the scheduler loop thread and then each active batch worker thread, + waiting at most ``timeout`` seconds per join, so callers can rely on clean state + (e.g. _reserved_vram == 0) after this returns unless a join timed out. + """ self._stop.set() self._wake.set() if self._thread and self._thread.is_alive(): self._thread.join(timeout=timeout) + # Join active batch workers so _reserved_vram is settled on return + with self._lock: + workers = list(self._active.values()) + for worker in workers: + worker.join(timeout=timeout) def _scheduler_loop(self) -> None: while not self._stop.is_set(): @@ -240,8 +257,70 @@ class TaskScheduler: self._reserved_vram += budget thread.start() + def _acquire_lease(self, task_type: str) -> Optional[str]: + """Request a VRAM lease from the coordinator. 
Returns lease_id or None.""" + if httpx is None: + return None + budget_gb = self._budgets.get(task_type, 0.0) + if budget_gb <= 0: + return None + mb = int(budget_gb * 1024) + try: + # Pick the GPU with the most free VRAM across all registered nodes + resp = httpx.get(f"{self._coordinator_url}/api/nodes", timeout=2.0) + if resp.status_code != 200: + return None + nodes = resp.json().get("nodes", []) + if not nodes: + return None + best_node = best_gpu = best_free = None + for node in nodes: + for gpu in node.get("gpus", []): + free = gpu.get("vram_free_mb", 0) + if best_free is None or free > best_free: + best_node = node["node_id"] + best_gpu = gpu["gpu_id"] + best_free = free + if best_node is None: + return None + lease_resp = httpx.post( + f"{self._coordinator_url}/api/leases", + json={ + "node_id": best_node, + "gpu_id": best_gpu, + "mb": mb, + "service": self._service_name, + "priority": self._lease_priority, + }, + timeout=3.0, + ) + if lease_resp.status_code == 200: + lease_id = lease_resp.json()["lease"]["lease_id"] + logger.debug( + "Acquired VRAM lease %s for task_type=%s (%d MB)", + lease_id, task_type, mb, + ) + return lease_id + except Exception as exc: + logger.debug("Lease acquire failed (non-fatal): %s", exc) + return None + + def _release_lease(self, lease_id: str) -> None: + """Release a coordinator VRAM lease. Best-effort; exceptions are logged at debug level and the response status is not checked.""" + if httpx is None or not lease_id: + return + try: + httpx.delete( + f"{self._coordinator_url}/api/leases/{lease_id}", + timeout=3.0, + ) + logger.debug("Released VRAM lease %s", lease_id) + except Exception as exc: + logger.debug("Lease release failed (non-fatal): %s", exc) + def _batch_worker(self, task_type: str) -> None: """Serial consumer for one task type. 
Runs until the type's deque is empty.""" + lease_id: Optional[str] = self._acquire_lease(task_type) try: while True: with self._lock: @@ -253,6 +332,8 @@ class TaskScheduler: self._db_path, task.id, task_type, task.job_id, task.params ) finally: + if lease_id: + self._release_lease(lease_id) with self._lock: self._active.pop(task_type, None) self._reserved_vram -= self._budgets.get(task_type, 0.0) @@ -298,6 +379,8 @@ def get_scheduler( task_types: Optional[frozenset[str]] = None, vram_budgets: Optional[dict[str, float]] = None, max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH, + coordinator_url: str = "http://localhost:7700", + service_name: str = "peregrine", ) -> TaskScheduler: """Return the process-level TaskScheduler singleton. @@ -324,6 +407,8 @@ def get_scheduler( task_types=task_types, vram_budgets=vram_budgets, max_queue_depth=max_queue_depth, + coordinator_url=coordinator_url, + service_name=service_name, ) candidate.start() with _scheduler_lock: