feat: cf-orch agent registration + VRAM lease wiring

Merges feature/orch-agent-registration into main.

- Agent self-registration and coordinator heartbeat loop
- TaskScheduler acquires/releases cf-orch VRAM lease per batch worker
- shutdown() now joins batch worker threads for clean teardown
- 94 tests passing
This commit is contained in:
pyr0ball 2026-04-01 11:21:38 -07:00
commit 427182aae7
3 changed files with 159 additions and 8 deletions

View file

@ -32,9 +32,14 @@ def start(
profile: Annotated[Optional[Path], typer.Option(help="Profile YAML path")] = None, profile: Annotated[Optional[Path], typer.Option(help="Profile YAML path")] = None,
host: str = "0.0.0.0", host: str = "0.0.0.0",
port: int = 7700, port: int = 7700,
node_id: str = "local",
agent_port: int = 7701, agent_port: int = 7701,
) -> None: ) -> None:
"""Start the cf-orch coordinator (auto-detects GPU profile if not specified).""" """Start the cf-orch coordinator (auto-detects GPU profile if not specified).
Automatically pre-registers the local agent so its GPUs appear on the
dashboard immediately. Remote nodes self-register via POST /api/nodes.
"""
from circuitforge_core.resources.coordinator.lease_manager import LeaseManager from circuitforge_core.resources.coordinator.lease_manager import LeaseManager
from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry
from circuitforge_core.resources.coordinator.agent_supervisor import AgentSupervisor from circuitforge_core.resources.coordinator.agent_supervisor import AgentSupervisor
@ -52,8 +57,6 @@ def start(
"Warning: no GPUs detected via nvidia-smi — coordinator running with 0 VRAM" "Warning: no GPUs detected via nvidia-smi — coordinator running with 0 VRAM"
) )
else: else:
for gpu in gpus:
lease_manager.register_gpu("local", gpu.gpu_id, gpu.vram_total_mb)
typer.echo(f"Detected {len(gpus)} GPU(s)") typer.echo(f"Detected {len(gpus)} GPU(s)")
if profile: if profile:
@ -67,6 +70,11 @@ def start(
) )
typer.echo(f"Auto-selected profile: {active_profile.name}") typer.echo(f"Auto-selected profile: {active_profile.name}")
# Pre-register the local agent — the heartbeat loop will poll it for live GPU data.
local_agent_url = f"http://127.0.0.1:{agent_port}"
supervisor.register(node_id, local_agent_url)
# Fix: the original f-string had no separator between the node id and the URL,
# printing e.g. "Registered local node 'local'http://127.0.0.1:7701".
typer.echo(f"Registered local node '{node_id}' at {local_agent_url}")
coordinator_app = create_coordinator_app( coordinator_app = create_coordinator_app(
lease_manager=lease_manager, lease_manager=lease_manager,
profile_registry=profile_registry, profile_registry=profile_registry,
@ -83,10 +91,47 @@ def agent(
node_id: str = "local", node_id: str = "local",
host: str = "0.0.0.0", host: str = "0.0.0.0",
port: int = 7701, port: int = 7701,
advertise_host: Optional[str] = None,
) -> None: ) -> None:
"""Start a cf-orch node agent (for remote nodes like Navi, Huginn).""" """Start a cf-orch node agent and self-register with the coordinator.
The agent starts its HTTP server, then POSTs its URL to the coordinator
so it appears on the dashboard without manual configuration.
Use --advertise-host to override the IP the coordinator should use to
reach this agent (e.g. on a multi-homed or NATted host).
"""
import asyncio
import threading
import httpx
from circuitforge_core.resources.agent.app import create_agent_app from circuitforge_core.resources.agent.app import create_agent_app
# The URL the coordinator should use to reach this agent.
reach_host = advertise_host or ("127.0.0.1" if host in ("0.0.0.0", "::") else host)
agent_url = f"http://{reach_host}:{port}"
def _register_in_background() -> None:
    """POST registration to coordinator after a short delay (uvicorn needs ~1s to bind)."""
    import time

    # Give uvicorn a moment to bind before advertising this agent's URL.
    time.sleep(2.0)
    payload = {"node_id": node_id, "agent_url": agent_url}
    try:
        response = httpx.post(f"{coordinator}/api/nodes", json=payload, timeout=5.0)
        if response.is_success:
            typer.echo(f"Registered with coordinator at {coordinator} as '{node_id}'")
        else:
            # Coordinator answered but rejected us — surface the status, keep serving.
            typer.echo(
                f"Warning: coordinator registration returned {response.status_code}", err=True
            )
    except Exception as exc:
        # Best-effort: the agent still runs; a remote node can be registered later.
        typer.echo(f"Warning: could not reach coordinator at {coordinator}: {exc}", err=True)
# Fire registration in a daemon thread so uvicorn.run() can start blocking immediately.
threading.Thread(target=_register_in_background, daemon=True).start()
agent_app = create_agent_app(node_id=node_id) agent_app = create_agent_app(node_id=node_id)
typer.echo(f"Starting cf-orch agent [{node_id}] on {host}:{port}") typer.echo(f"Starting cf-orch agent [{node_id}] on {host}:{port}")
uvicorn.run(agent_app, host=host, port=port) uvicorn.run(agent_app, host=host, port=port)

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -7,13 +8,13 @@ from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from pydantic import BaseModel from pydantic import BaseModel
_DASHBOARD_HTML = (Path(__file__).parent / "dashboard.html").read_text()
from circuitforge_core.resources.coordinator.agent_supervisor import AgentSupervisor from circuitforge_core.resources.coordinator.agent_supervisor import AgentSupervisor
from circuitforge_core.resources.coordinator.eviction_engine import EvictionEngine from circuitforge_core.resources.coordinator.eviction_engine import EvictionEngine
from circuitforge_core.resources.coordinator.lease_manager import LeaseManager from circuitforge_core.resources.coordinator.lease_manager import LeaseManager
from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry
_DASHBOARD_HTML = (Path(__file__).parent / "dashboard.html").read_text()
class LeaseRequest(BaseModel): class LeaseRequest(BaseModel):
node_id: str node_id: str
@ -24,6 +25,11 @@ class LeaseRequest(BaseModel):
ttl_s: float = 0.0 ttl_s: float = 0.0
class NodeRegisterRequest(BaseModel):
node_id: str
agent_url: str # e.g. "http://10.1.10.71:7701"
def create_coordinator_app( def create_coordinator_app(
lease_manager: LeaseManager, lease_manager: LeaseManager,
profile_registry: ProfileRegistry, profile_registry: ProfileRegistry,
@ -31,7 +37,15 @@ def create_coordinator_app(
) -> FastAPI: ) -> FastAPI:
eviction_engine = EvictionEngine(lease_manager=lease_manager) eviction_engine = EvictionEngine(lease_manager=lease_manager)
@asynccontextmanager
async def _lifespan(app: FastAPI):  # type: ignore[type-arg]
    """Run the agent-supervisor heartbeat loop for the lifetime of the app.

    Startup: spawn the heartbeat polling task.
    Shutdown: stop the supervisor, cancel the task, and await it so no
    pending cancelled task is left behind (avoids "Task exception never
    retrieved" warnings and guarantees clean teardown).
    """
    import asyncio

    task = asyncio.create_task(agent_supervisor.run_heartbeat_loop())
    yield
    agent_supervisor.stop()
    task.cancel()
    # Fix: the original never awaited the cancelled task, so shutdown could
    # complete with the heartbeat coroutine still pending.
    try:
        await task
    except asyncio.CancelledError:
        pass

app = FastAPI(title="cf-orch-coordinator", lifespan=_lifespan)
@app.get("/", response_class=HTMLResponse, include_in_schema=False) @app.get("/", response_class=HTMLResponse, include_in_schema=False)
def dashboard() -> HTMLResponse: def dashboard() -> HTMLResponse:
@ -65,6 +79,13 @@ def create_coordinator_app(
] ]
} }
@app.post("/api/nodes")
async def register_node(req: NodeRegisterRequest) -> dict[str, Any]:
"""Agents call this to self-register. Coordinator immediately polls for GPU info."""
agent_supervisor.register(req.node_id, req.agent_url)
await agent_supervisor.poll_agent(req.node_id)
return {"registered": True, "node_id": req.node_id}
@app.get("/api/profiles") @app.get("/api/profiles")
def get_profiles() -> dict[str, Any]: def get_profiles() -> dict[str, Any]:
return { return {

View file

@ -127,6 +127,9 @@ class TaskScheduler:
vram_budgets: dict[str, float], vram_budgets: dict[str, float],
available_vram_gb: Optional[float] = None, available_vram_gb: Optional[float] = None,
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH, max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
coordinator_url: str = "http://localhost:7700",
service_name: str = "peregrine",
lease_priority: int = 2,
) -> None: ) -> None:
self._db_path = db_path self._db_path = db_path
self._run_task = run_task_fn self._run_task = run_task_fn
@ -134,6 +137,10 @@ class TaskScheduler:
self._budgets: dict[str, float] = dict(vram_budgets) self._budgets: dict[str, float] = dict(vram_budgets)
self._max_queue_depth = max_queue_depth self._max_queue_depth = max_queue_depth
self._coordinator_url = coordinator_url.rstrip("/")
self._service_name = service_name
self._lease_priority = lease_priority
self._lock = threading.Lock() self._lock = threading.Lock()
self._wake = threading.Event() self._wake = threading.Event()
self._stop = threading.Event() self._stop = threading.Event()
@ -196,11 +203,21 @@ class TaskScheduler:
self._wake.set() self._wake.set()
def shutdown(self, timeout: float = 5.0) -> None: def shutdown(self, timeout: float = 5.0) -> None:
"""Signal the scheduler to stop and wait for it to exit.""" """Signal the scheduler to stop and wait for it to exit.
Joins both the scheduler loop thread and any active batch worker
threads so callers can rely on clean state (e.g. _reserved_vram == 0)
immediately after this returns.
"""
self._stop.set() self._stop.set()
self._wake.set() self._wake.set()
if self._thread and self._thread.is_alive(): if self._thread and self._thread.is_alive():
self._thread.join(timeout=timeout) self._thread.join(timeout=timeout)
# Join active batch workers so _reserved_vram is settled on return
with self._lock:
workers = list(self._active.values())
for worker in workers:
worker.join(timeout=timeout)
def _scheduler_loop(self) -> None: def _scheduler_loop(self) -> None:
while not self._stop.is_set(): while not self._stop.is_set():
@ -240,8 +257,70 @@ class TaskScheduler:
self._reserved_vram += budget self._reserved_vram += budget
thread.start() thread.start()
def _acquire_lease(self, task_type: str) -> Optional[str]:
    """Request a VRAM lease from the coordinator. Returns lease_id or None."""
    # Leasing is optional: skip entirely when httpx is unavailable.
    if httpx is None:
        return None
    budget_gb = self._budgets.get(task_type, 0.0)
    if budget_gb <= 0:
        return None
    requested_mb = int(budget_gb * 1024)
    try:
        # Ask the coordinator for its node inventory, then pick the GPU with
        # the most free VRAM across every registered node (first wins on ties).
        nodes_resp = httpx.get(f"{self._coordinator_url}/api/nodes", timeout=2.0)
        if nodes_resp.status_code != 200:
            return None
        nodes = nodes_resp.json().get("nodes", [])
        if not nodes:
            return None
        chosen_node = None
        chosen_gpu = None
        most_free = None
        for node in nodes:
            for gpu in node.get("gpus", []):
                free = gpu.get("vram_free_mb", 0)
                if most_free is not None and free <= most_free:
                    continue
                chosen_node = node["node_id"]
                chosen_gpu = gpu["gpu_id"]
                most_free = free
        if chosen_node is None:
            return None
        lease_resp = httpx.post(
            f"{self._coordinator_url}/api/leases",
            json={
                "node_id": chosen_node,
                "gpu_id": chosen_gpu,
                "mb": requested_mb,
                "service": self._service_name,
                "priority": self._lease_priority,
            },
            timeout=3.0,
        )
        if lease_resp.status_code == 200:
            lease_id = lease_resp.json()["lease"]["lease_id"]
            logger.debug(
                "Acquired VRAM lease %s for task_type=%s (%d MB)",
                lease_id, task_type, requested_mb,
            )
            return lease_id
    except Exception as exc:
        # Best-effort: a coordinator outage must never block local scheduling.
        logger.debug("Lease acquire failed (non-fatal): %s", exc)
    return None
def _release_lease(self, lease_id: str) -> None:
    """Release a coordinator VRAM lease. Best-effort; failures are logged only."""
    # Nothing to do when leasing is disabled or no lease was ever acquired.
    if httpx is None:
        return
    if not lease_id:
        return
    url = f"{self._coordinator_url}/api/leases/{lease_id}"
    try:
        httpx.delete(url, timeout=3.0)
    except Exception as exc:
        logger.debug("Lease release failed (non-fatal): %s", exc)
    else:
        logger.debug("Released VRAM lease %s", lease_id)
def _batch_worker(self, task_type: str) -> None: def _batch_worker(self, task_type: str) -> None:
"""Serial consumer for one task type. Runs until the type's deque is empty.""" """Serial consumer for one task type. Runs until the type's deque is empty."""
lease_id: Optional[str] = self._acquire_lease(task_type)
try: try:
while True: while True:
with self._lock: with self._lock:
@ -253,6 +332,8 @@ class TaskScheduler:
self._db_path, task.id, task_type, task.job_id, task.params self._db_path, task.id, task_type, task.job_id, task.params
) )
finally: finally:
if lease_id:
self._release_lease(lease_id)
with self._lock: with self._lock:
self._active.pop(task_type, None) self._active.pop(task_type, None)
self._reserved_vram -= self._budgets.get(task_type, 0.0) self._reserved_vram -= self._budgets.get(task_type, 0.0)
@ -298,6 +379,8 @@ def get_scheduler(
task_types: Optional[frozenset[str]] = None, task_types: Optional[frozenset[str]] = None,
vram_budgets: Optional[dict[str, float]] = None, vram_budgets: Optional[dict[str, float]] = None,
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH, max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
coordinator_url: str = "http://localhost:7700",
service_name: str = "peregrine",
) -> TaskScheduler: ) -> TaskScheduler:
"""Return the process-level TaskScheduler singleton. """Return the process-level TaskScheduler singleton.
@ -324,6 +407,8 @@ def get_scheduler(
task_types=task_types, task_types=task_types,
vram_budgets=vram_budgets, vram_budgets=vram_budgets,
max_queue_depth=max_queue_depth, max_queue_depth=max_queue_depth,
coordinator_url=coordinator_url,
service_name=service_name,
) )
candidate.start() candidate.start()
with _scheduler_lock: with _scheduler_lock: