diff --git a/circuitforge_core/resources/coordinator/node_selector.py b/circuitforge_core/resources/coordinator/node_selector.py
new file mode 100644
index 0000000..665cbb5
--- /dev/null
+++ b/circuitforge_core/resources/coordinator/node_selector.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from circuitforge_core.resources.coordinator.agent_supervisor import AgentRecord
+    from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry
+
+# Added to a GPU's free VRAM when the service is already resident on its node.
+_WARM_BONUS_MB = 1000
+
+
+@dataclass
+class _Scored:
+    """Scoring record for one (node, GPU) placement candidate."""
+
+    node_id: str
+    gpu_id: int
+    vram_free_mb: int
+    effective_free_mb: int  # vram_free_mb plus the warm bonus, if any
+    can_fit: bool  # free VRAM is at least half of the service's max_mb
+    warm: bool  # the "{node_id}:{service}" key is present in resident_keys
+
+
+def select_node(
+    agents: "dict[str, AgentRecord]",
+    service: str,
+    profile_registry: "ProfileRegistry",
+    resident_keys: set[str],
+) -> tuple[str, int] | None:
+    """
+    Pick the best (node_id, gpu_id) for the requested service.
+
+    Warm nodes (a "{node_id}:{service}" key is in resident_keys) get priority,
+    then candidates are ranked by free VRAM.
+
+    Returns None if no public profile defines the service, or if no online
+    agent reports a GPU.
+    """
+    # The profile lookup does not depend on the agent, so do it once up front.
+    service_max_mb = _find_service_max_mb(service, profile_registry)
+    if service_max_mb is None:
+        return None
+    # Cold GPUs count as fitting once half of the service's max_mb is free.
+    fit_threshold_mb = service_max_mb // 2
+    candidates: list[_Scored] = []
+    for node_id, record in agents.items():
+        if not record.online:
+            continue
+        # Residency is tracked per node, so every GPU on a warm node is warm.
+        warm = f"{node_id}:{service}" in resident_keys
+        bonus_mb = _WARM_BONUS_MB if warm else 0
+        for gpu in record.gpus:
+            candidates.append(_Scored(
+                node_id=node_id,
+                gpu_id=gpu.gpu_id,
+                vram_free_mb=gpu.vram_free_mb,
+                effective_free_mb=gpu.vram_free_mb + bonus_mb,
+                can_fit=gpu.vram_free_mb >= fit_threshold_mb,
+                warm=warm,
+            ))
+    if not candidates:
+        return None
+    # Warm nodes are always eligible (they already have the service resident).
+    # Cold nodes must pass the can_fit threshold. If no node passes either
+    # criterion, fall back to the full candidate set.
+    preferred = [c for c in candidates if c.warm or c.can_fit]
+    pool = preferred if preferred else candidates
+    # Warm nodes take priority; within the same warmth tier, prefer more free VRAM.
+    best = max(pool, key=lambda c: (c.warm, c.effective_free_mb))
+    return best.node_id, best.gpu_id
+
+
+def _find_service_max_mb(service: str, profile_registry: "ProfileRegistry") -> int | None:
+    """Return max_mb of the first public profile defining the service, else None."""
+    for profile in profile_registry.list_public():
+        svc = profile.services.get(service)
+        if svc is not None:
+            return svc.max_mb
+    return None
diff --git a/tests/test_resources/test_node_selector.py b/tests/test_resources/test_node_selector.py
new file mode 100644
index 0000000..9e18a3a
--- /dev/null
+++ b/tests/test_resources/test_node_selector.py
@@ -0,0 +1,56 @@
+import pytest
+from circuitforge_core.resources.coordinator.node_selector import select_node
+from circuitforge_core.resources.coordinator.agent_supervisor import AgentRecord
+from circuitforge_core.resources.models import GpuInfo
+from circuitforge_core.resources.coordinator.profile_registry import ProfileRegistry
+
+
+def _make_agent(node_id: str, free_mb: int, online: bool = True) -> AgentRecord:
+    r = AgentRecord(node_id=node_id, agent_url=f"http://{node_id}:7701")
+    r.gpus = [GpuInfo(gpu_id=0, name="RTX", vram_total_mb=8192,
+                      vram_used_mb=8192 - free_mb, vram_free_mb=free_mb)]
+    r.online = online
+    return r
+
+
+def test_selects_node_with_most_free_vram():
+    agents = {
+        "a": _make_agent("a", free_mb=2000),
+        "b": _make_agent("b", free_mb=6000),
+    }
+    registry = ProfileRegistry()
+    result = select_node(agents, "vllm", registry, resident_keys=set())
+    assert result == ("b", 0)
+
+
+def test_prefers_warm_node_even_with_less_free_vram():
+    agents = {
+        "a": _make_agent("a", free_mb=2000),
+        "b": _make_agent("b", free_mb=6000),
+    }
+    registry = ProfileRegistry()
+    result = select_node(agents, "vllm", registry, resident_keys={"a:vllm"})
+    assert result == ("a", 0)
+
+
+def test_excludes_offline_nodes():
+    agents = {
+        "a": _make_agent("a", free_mb=8000, online=False),
+        "b": _make_agent("b", free_mb=2000, online=True),
+    }
+    registry = ProfileRegistry()
+    result = select_node(agents, "vllm", registry, resident_keys=set())
+    assert result == ("b", 0)
+
+
+def test_returns_none_when_no_node_has_profile_for_service():
+    agents = {"a": _make_agent("a", free_mb=8000)}
+    registry = ProfileRegistry()
+    result = select_node(agents, "cf-nonexistent-service", registry, resident_keys=set())
+    assert result is None
+
+
+def test_returns_none_when_no_agents():
+    registry = ProfileRegistry()
+    result = select_node({}, "vllm", registry, resident_keys=set())
+    assert result is None