diff --git a/circuitforge_core/resources/cli.py b/circuitforge_core/resources/cli.py index dc6883f..7238507 100644 --- a/circuitforge_core/resources/cli.py +++ b/circuitforge_core/resources/cli.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import sys from pathlib import Path from typing import Annotated, Optional @@ -7,6 +8,8 @@ from typing import Annotated, Optional import typer import uvicorn +logger = logging.getLogger(__name__) + app = typer.Typer(name="cf-orch", help="CircuitForge GPU resource orchestrator") _SYSTEMD_UNIT_PATH = Path("/etc/systemd/system/cf-orch.service") @@ -47,14 +50,21 @@ def start( from circuitforge_core.resources.coordinator.service_registry import ServiceRegistry from circuitforge_core.resources.agent.gpu_monitor import GpuMonitor + from circuitforge_core.resources.coordinator.node_store import NodeStore + lease_manager = LeaseManager() profile_registry = ProfileRegistry() service_registry = ServiceRegistry() + node_store = NodeStore() supervisor = AgentSupervisor( lease_manager=lease_manager, service_registry=service_registry, profile_registry=profile_registry, + node_store=node_store, ) + restored = supervisor.restore_from_store() + if restored: + typer.echo(f"Restored {restored} known node(s) from previous session") monitor = GpuMonitor() gpus = monitor.poll() @@ -119,27 +129,43 @@ def agent( reach_host = advertise_host or ("127.0.0.1" if host in ("0.0.0.0", "::") else host) agent_url = f"http://{reach_host}:{port}" - def _register_in_background() -> None: - """POST registration to coordinator after a short delay (uvicorn needs ~1s to bind).""" - import time - time.sleep(2.0) - try: - resp = httpx.post( - f"{coordinator}/api/nodes", - json={"node_id": node_id, "agent_url": agent_url}, - timeout=5.0, - ) - if resp.is_success: - typer.echo(f"Registered with coordinator at {coordinator} as '{node_id}'") - else: - typer.echo( - f"Warning: coordinator registration returned {resp.status_code}", err=True - ) - except 
Exception as exc: - typer.echo(f"Warning: could not reach coordinator at {coordinator}: {exc}", err=True) + _RECONNECT_INTERVAL_S = 30.0 - # Fire registration in a daemon thread so uvicorn.run() can start blocking immediately. - threading.Thread(target=_register_in_background, daemon=True).start() + def _reconnect_loop() -> None: + """ + Persistently re-register this agent with the coordinator. + + Runs as a daemon thread for the lifetime of the agent process: + - Waits 2 s on first run (uvicorn needs time to bind) + - Re-registers every 30 s thereafter + - If the coordinator is down, silently retries — no crashing + - When the coordinator restarts, the agent re-appears within one cycle + + This means coordinator restarts require no manual intervention on agent hosts. + """ + import time + first = True + while True: + time.sleep(2.0 if first else _RECONNECT_INTERVAL_S) + first = False + try: + resp = httpx.post( + f"{coordinator}/api/nodes", + json={"node_id": node_id, "agent_url": agent_url}, + timeout=5.0, + ) + if resp.is_success: + logger.debug("Registered with coordinator at %s as '%s'", coordinator, node_id) + else: + logger.warning( + "Coordinator registration returned %s", resp.status_code + ) + except Exception as exc: + logger.debug("Coordinator at %s unreachable, will retry: %s", coordinator, exc) + + # Fire reconnect loop in a daemon thread so uvicorn.run() can start blocking immediately. 
+ threading.Thread(target=_reconnect_loop, daemon=True, name="cf-orch-reconnect").start() + typer.echo(f"Reconnect loop started — will register with {coordinator} every {int(_RECONNECT_INTERVAL_S)}s") service_manager = None try: diff --git a/circuitforge_core/resources/coordinator/agent_supervisor.py b/circuitforge_core/resources/coordinator/agent_supervisor.py index 8536636..503c8c5 100644 --- a/circuitforge_core/resources/coordinator/agent_supervisor.py +++ b/circuitforge_core/resources/coordinator/agent_supervisor.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field import httpx from circuitforge_core.resources.coordinator.lease_manager import LeaseManager +from circuitforge_core.resources.coordinator.node_store import NodeStore from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry from circuitforge_core.resources.coordinator.service_registry import ServiceRegistry from circuitforge_core.resources.models import GpuInfo, NodeInfo, ResidentAllocation @@ -33,14 +34,38 @@ class AgentSupervisor: lease_manager: LeaseManager, service_registry: ServiceRegistry | None = None, profile_registry: ProfileRegistry | None = None, + node_store: NodeStore | None = None, ) -> None: self._agents: dict[str, AgentRecord] = {} self._lease_manager = lease_manager self._running = False self._service_registry = service_registry self._profile_registry = profile_registry + self._node_store = node_store self._heartbeat_tick = 0 + def restore_from_store(self) -> int: + """ + Load previously-known nodes from NodeStore into the in-memory registry. + + All restored nodes start as offline=False. The heartbeat loop will poll + them on its first tick and promote any that respond to online=True. + + Returns the number of nodes restored. 
+ """ + if self._node_store is None: + return 0 + restored = 0 + for node_id, agent_url in self._node_store.all(): + if node_id not in self._agents: + self._agents[node_id] = AgentRecord( + node_id=node_id, agent_url=agent_url, online=False + ) + restored += 1 + if restored: + logger.info("NodeStore: restored %d known node(s) from previous session", restored) + return restored + def register(self, node_id: str, agent_url: str) -> None: if node_id not in self._agents: self._agents[node_id] = AgentRecord(node_id=node_id, agent_url=agent_url) @@ -49,6 +74,8 @@ class AgentSupervisor: if self._agents[node_id].agent_url != agent_url: self._agents[node_id].agent_url = agent_url logger.info("Updated agent URL for %s → %s", node_id, agent_url) + if self._node_store is not None: + self._node_store.upsert(node_id, agent_url) def get_node_info(self, node_id: str) -> NodeInfo | None: record = self._agents.get(node_id) diff --git a/circuitforge_core/resources/coordinator/node_store.py b/circuitforge_core/resources/coordinator/node_store.py new file mode 100644 index 0000000..8dc71f9 --- /dev/null +++ b/circuitforge_core/resources/coordinator/node_store.py @@ -0,0 +1,85 @@ +""" +circuitforge_core.resources.coordinator.node_store — SQLite persistence for known agent nodes. + +Gives the coordinator restart-safe memory of which nodes have ever registered. +On startup the coordinator reloads all known nodes and immediately probes them; +nodes that respond come back online within one heartbeat cycle (~10 s) without +any manual intervention on the agent hosts. 
+""" +from __future__ import annotations + +import logging +import sqlite3 +import time +from pathlib import Path + +logger = logging.getLogger(__name__) + +_DEFAULT_DB_PATH = Path.home() / ".local" / "share" / "circuitforge" / "cf-orch-nodes.db" +_STALE_AGE_DAYS = 30 # nodes unseen for this long are pruned automatically + + +class NodeStore: + """ + Thin SQLite wrapper for persisting known agent nodes across coordinator restarts. + + Thread-safe for single-writer use (coordinator runs in one asyncio thread). + """ + + def __init__(self, db_path: Path = _DEFAULT_DB_PATH) -> None: + self.db_path = db_path + db_path.parent.mkdir(parents=True, exist_ok=True) + self._conn = sqlite3.connect(str(db_path), check_same_thread=False) + self._conn.row_factory = sqlite3.Row + self._migrate() + logger.debug("NodeStore initialised at %s", db_path) + + def _migrate(self) -> None: + self._conn.executescript(""" + CREATE TABLE IF NOT EXISTS known_nodes ( + node_id TEXT PRIMARY KEY, + agent_url TEXT NOT NULL, + last_seen REAL NOT NULL + ); + """) + self._conn.commit() + + def upsert(self, node_id: str, agent_url: str) -> None: + """Record or update a node. Called on every successful registration.""" + self._conn.execute( + """ + INSERT INTO known_nodes (node_id, agent_url, last_seen) + VALUES (?, ?, ?) 
+ ON CONFLICT(node_id) DO UPDATE SET + agent_url = excluded.agent_url, + last_seen = excluded.last_seen + """, + (node_id, agent_url, time.time()), + ) + self._conn.commit() + + def all(self) -> list[tuple[str, str]]: + """Return all known (node_id, agent_url) pairs.""" + rows = self._conn.execute( + "SELECT node_id, agent_url FROM known_nodes ORDER BY last_seen DESC" + ).fetchall() + return [(r["node_id"], r["agent_url"]) for r in rows] + + def remove(self, node_id: str) -> None: + self._conn.execute("DELETE FROM known_nodes WHERE node_id = ?", (node_id,)) + self._conn.commit() + + def prune_stale(self, max_age_days: int = _STALE_AGE_DAYS) -> int: + """Delete nodes not seen within max_age_days. Returns count removed.""" + cutoff = time.time() - max_age_days * 86400 + cur = self._conn.execute( + "DELETE FROM known_nodes WHERE last_seen < ?", (cutoff,) + ) + self._conn.commit() + removed = cur.rowcount + if removed: + logger.info("NodeStore: pruned %d stale node(s) (>%d days old)", removed, max_age_days) + return removed + + def close(self) -> None: + self._conn.close() diff --git a/tests/test_resources/test_agent_watchdog.py b/tests/test_resources/test_agent_watchdog.py new file mode 100644 index 0000000..78b6d09 --- /dev/null +++ b/tests/test_resources/test_agent_watchdog.py @@ -0,0 +1,151 @@ +# tests/test_resources/test_agent_watchdog.py +""" +Tests for AgentSupervisor watchdog behaviour: + - restore_from_store() reloads known nodes from NodeStore on startup + - register() persists to NodeStore + - restored nodes start offline and come online after a successful poll + - NodeStore=None path is a no-op (backwards compatibility) +""" +from __future__ import annotations + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from circuitforge_core.resources.coordinator.agent_supervisor import AgentSupervisor +from circuitforge_core.resources.coordinator.lease_manager import LeaseManager +from 
# ── fixtures ──────────────────────────────────────────────────────────────────

@pytest.fixture
def store(tmp_path: Path) -> NodeStore:
    """A fresh NodeStore backed by a throwaway SQLite file."""
    return NodeStore(db_path=tmp_path / "nodes.db")


@pytest.fixture
def supervisor(store: NodeStore) -> AgentSupervisor:
    """Supervisor wired to the throwaway NodeStore."""
    return AgentSupervisor(lease_manager=LeaseManager(), node_store=store)


@pytest.fixture
def supervisor_no_store() -> AgentSupervisor:
    """Supervisor with persistence disabled (backwards-compat path)."""
    return AgentSupervisor(lease_manager=LeaseManager(), node_store=None)


# ── register() persists ───────────────────────────────────────────────────────

def test_register_persists_to_store(supervisor: AgentSupervisor, store: NodeStore) -> None:
    supervisor.register("heimdall", "http://127.0.0.1:7701")
    assert store.all() == [("heimdall", "http://127.0.0.1:7701")]


def test_register_updates_url_in_store(supervisor: AgentSupervisor, store: NodeStore) -> None:
    supervisor.register("navi", "http://10.1.10.10:7701")
    supervisor.register("navi", "http://10.1.10.10:9999")
    entries = store.all()
    assert len(entries) == 1
    assert entries[0][1] == "http://10.1.10.10:9999"


def test_register_without_store_does_not_crash(supervisor_no_store: AgentSupervisor) -> None:
    supervisor_no_store.register("heimdall", "http://127.0.0.1:7701")
    assert supervisor_no_store.get_node_info("heimdall") is not None


# ── restore_from_store() ──────────────────────────────────────────────────────

def test_restore_loads_known_nodes(tmp_path: Path) -> None:
    """Nodes written by a previous supervisor session are restored into a fresh one."""
    db_path = tmp_path / "nodes.db"

    # Session 1: register two nodes against the shared DB.
    first_sup = AgentSupervisor(
        lease_manager=LeaseManager(), node_store=NodeStore(db_path=db_path)
    )
    first_sup.register("navi", "http://10.1.10.10:7701")
    first_sup.register("strahl", "http://10.1.10.20:7701")

    # Session 2: brand-new supervisor, same DB file.
    second_sup = AgentSupervisor(
        lease_manager=LeaseManager(), node_store=NodeStore(db_path=db_path)
    )
    assert second_sup.restore_from_store() == 2
    assert second_sup.get_node_info("navi") is not None
    assert second_sup.get_node_info("strahl") is not None


def test_restore_marks_nodes_offline(tmp_path: Path) -> None:
    """Restored nodes start offline — they haven't been polled yet."""
    db_path = tmp_path / "nodes.db"

    AgentSupervisor(
        lease_manager=LeaseManager(), node_store=NodeStore(db_path=db_path)
    ).register("navi", "http://10.1.10.10:7701")

    fresh = AgentSupervisor(
        lease_manager=LeaseManager(), node_store=NodeStore(db_path=db_path)
    )
    fresh.restore_from_store()
    assert fresh.online_agents() == {}


def test_restore_returns_zero_without_store() -> None:
    sup = AgentSupervisor(lease_manager=LeaseManager(), node_store=None)
    assert sup.restore_from_store() == 0


def test_restore_skips_already_registered(tmp_path: Path) -> None:
    """Nodes manually registered before restore_from_store() are not duplicated."""
    store = NodeStore(db_path=tmp_path / "nodes.db")
    store.upsert("heimdall", "http://127.0.0.1:7701")

    sup = AgentSupervisor(lease_manager=LeaseManager(), node_store=store)
    sup.register("heimdall", "http://127.0.0.1:7701")  # already in memory
    assert sup.restore_from_store() == 0  # already present, not double-counted


# ── restored node comes online after poll ─────────────────────────────────────

@pytest.mark.asyncio
async def test_restored_node_comes_online_after_poll(tmp_path: Path) -> None:
    """After restore, a successful poll_agent() brings the node online."""
    store = NodeStore(db_path=tmp_path / "nodes.db")
    store.upsert("navi", "http://10.1.10.10:7701")

    sup = AgentSupervisor(lease_manager=LeaseManager(), node_store=store)
    sup.restore_from_store()

    # Stub the two agent HTTP calls poll_agent() makes: GPUs then residents.
    gpu_payload = {
        "gpus": [
            {
                "gpu_id": 0,
                "name": "RTX 4000",
                "vram_total_mb": 8192,
                "vram_used_mb": 512,
                "vram_free_mb": 7680,
            }
        ]
    }
    resident_payload = {"residents": []}

    gpu_resp = MagicMock()
    gpu_resp.raise_for_status = MagicMock()
    gpu_resp.json.return_value = gpu_payload

    residents_resp = MagicMock()
    residents_resp.is_success = True
    residents_resp.json.return_value = resident_payload

    client = AsyncMock()
    client.get = AsyncMock(side_effect=[gpu_resp, residents_resp])
    client.__aenter__ = AsyncMock(return_value=client)
    client.__aexit__ = AsyncMock(return_value=False)

    with patch(
        "circuitforge_core.resources.coordinator.agent_supervisor.httpx.AsyncClient",
        return_value=client,
    ):
        result = await sup.poll_agent("navi")

    assert result is True
    assert "navi" in sup.online_agents()
8192, "vram_used_mb": 512, "vram_free_mb": 7680}]} + resident_payload = {"residents": []} + + mock_resp_gpu = MagicMock() + mock_resp_gpu.raise_for_status = MagicMock() + mock_resp_gpu.json.return_value = gpu_payload + + mock_resp_res = MagicMock() + mock_resp_res.is_success = True + mock_resp_res.json.return_value = resident_payload + + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=[mock_resp_gpu, mock_resp_res]) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch("circuitforge_core.resources.coordinator.agent_supervisor.httpx.AsyncClient", + return_value=mock_client): + result = await sup.poll_agent("navi") + + assert result is True + assert "navi" in sup.online_agents() diff --git a/tests/test_resources/test_node_store.py b/tests/test_resources/test_node_store.py new file mode 100644 index 0000000..91b6e0c --- /dev/null +++ b/tests/test_resources/test_node_store.py @@ -0,0 +1,87 @@ +# tests/test_resources/test_node_store.py +"""Unit tests for NodeStore — SQLite persistence layer for known agent nodes.""" +from __future__ import annotations + +import time +from pathlib import Path + +import pytest + +from circuitforge_core.resources.coordinator.node_store import NodeStore + + +@pytest.fixture +def store(tmp_path: Path) -> NodeStore: + return NodeStore(db_path=tmp_path / "test-nodes.db") + + +def test_upsert_and_all(store: NodeStore) -> None: + store.upsert("heimdall", "http://127.0.0.1:7701") + rows = store.all() + assert len(rows) == 1 + assert rows[0] == ("heimdall", "http://127.0.0.1:7701") + + +def test_upsert_updates_url(store: NodeStore) -> None: + store.upsert("navi", "http://10.1.10.10:7701") + store.upsert("navi", "http://10.1.10.10:7702") + rows = store.all() + assert len(rows) == 1 + assert rows[0][1] == "http://10.1.10.10:7702" + + +def test_multiple_nodes(store: NodeStore) -> None: + store.upsert("heimdall", "http://127.0.0.1:7701") + 
store.upsert("navi", "http://10.1.10.10:7701") + store.upsert("strahl", "http://10.1.10.20:7701") + assert len(store.all()) == 3 + + +def test_remove(store: NodeStore) -> None: + store.upsert("heimdall", "http://127.0.0.1:7701") + store.upsert("navi", "http://10.1.10.10:7701") + store.remove("navi") + ids = [r[0] for r in store.all()] + assert "navi" not in ids + assert "heimdall" in ids + + +def test_prune_stale_removes_old_entries(store: NodeStore) -> None: + # Insert a node with a last_seen in the distant past + store._conn.execute( + "INSERT INTO known_nodes (node_id, agent_url, last_seen) VALUES (?, ?, ?)", + ("ghost", "http://dead:7701", time.time() - 40 * 86400), + ) + store._conn.commit() + store.upsert("live", "http://live:7701") + + removed = store.prune_stale(max_age_days=30) + assert removed == 1 + ids = [r[0] for r in store.all()] + assert "ghost" not in ids + assert "live" in ids + + +def test_prune_stale_keeps_recent(store: NodeStore) -> None: + store.upsert("recent", "http://recent:7701") + removed = store.prune_stale(max_age_days=30) + assert removed == 0 + assert len(store.all()) == 1 + + +def test_all_empty(store: NodeStore) -> None: + assert store.all() == [] + + +def test_db_persists_across_instances(tmp_path: Path) -> None: + """Data written by one NodeStore instance is visible to a new one on the same file.""" + db = tmp_path / "shared.db" + s1 = NodeStore(db_path=db) + s1.upsert("navi", "http://10.1.10.10:7701") + s1.close() + + s2 = NodeStore(db_path=db) + rows = s2.all() + assert len(rows) == 1 + assert rows[0][0] == "navi" + s2.close()