Merge pull request 'feat(orch): health probe loop + VRAM pre-flight fix' (#12) from feature/orch-llm-server into main

This commit is contained in:
pyr0ball 2026-04-02 17:24:09 -07:00
commit 749e51ccca
17 changed files with 799 additions and 29 deletions

3
.gitignore vendored
View file

@ -5,6 +5,9 @@ __pycache__/
dist/
.pytest_cache/
.superpowers/
.coverage
build/
"<MagicMock*"
# cf-orch private profiles (commit on personal/heimdall branch only)
circuitforge_core/resources/profiles/private/

View file

@ -27,4 +27,8 @@ def get_connection(db_path: Path, key: str = "") -> sqlite3.Connection:
return conn
# timeout=30: retry for up to 30s when another writer holds the lock (WAL mode
# allows concurrent readers but only one writer at a time).
return sqlite3.connect(str(db_path), timeout=30)
# check_same_thread=False: each Store is owned by exactly one request; FastAPI
# uses asyncio.to_thread() to run sync DB calls in a worker thread, crossing
# the thread boundary that sqlite3 guards by default. Since no two threads share
# the same connection, disabling the check is safe.
return sqlite3.connect(str(db_path), timeout=30, check_same_thread=False)

View file

@ -3,11 +3,12 @@ from __future__ import annotations
import logging
from typing import Any
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from circuitforge_core.resources.agent.eviction_executor import EvictionExecutor
from circuitforge_core.resources.agent.gpu_monitor import GpuMonitor
from circuitforge_core.resources.agent.service_manager import ServiceManager
logger = logging.getLogger(__name__)
@ -17,10 +18,16 @@ class EvictRequest(BaseModel):
grace_period_s: float = 5.0
class ServiceStartRequest(BaseModel):
gpu_id: int = 0
params: dict[str, str] = {}
def create_agent_app(
node_id: str,
monitor: GpuMonitor | None = None,
executor: EvictionExecutor | None = None,
service_manager: ServiceManager | None = None,
) -> FastAPI:
_monitor = monitor or GpuMonitor()
_executor = executor or EvictionExecutor()
@ -57,4 +64,38 @@ def create_agent_app(
"message": result.message,
}
@app.get("/resident-info")
def resident_info() -> dict[str, Any]:
"""Return which models are currently loaded in each running managed service."""
if service_manager is None:
return {"residents": []}
from circuitforge_core.resources.agent.service_probe import probe_all
return {"residents": probe_all(service_manager)}
if service_manager is not None:
@app.get("/services")
def list_services() -> dict:
return {"running": service_manager.list_running()}
@app.get("/services/{service}")
def service_status(service: str) -> dict:
running = service_manager.is_running(service)
url = service_manager.get_url(service) if running else None
return {"service": service, "running": running, "url": url}
@app.post("/services/{service}/start")
def start_service(service: str, req: ServiceStartRequest) -> dict:
try:
url = service_manager.start(service, req.gpu_id, req.params)
return {"service": service, "url": url, "running": True}
except (ValueError, NotImplementedError) as exc:
raise HTTPException(status_code=422, detail=str(exc))
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Failed to start {service}: {exc}")
@app.post("/services/{service}/stop")
def stop_service(service: str) -> dict:
stopped = service_manager.stop(service)
return {"service": service, "stopped": stopped}
return app

View file

@ -0,0 +1,169 @@
"""
ServiceManager starts and stops Docker containers and processes for cf-orch managed services.
Container naming convention: cf-orch-{service}-{node_id}
"""
from __future__ import annotations
import os
import re
import subprocess
from collections import defaultdict
from typing import Any
from circuitforge_core.resources.profiles.schema import DockerSpec, GpuProfile, ProcessSpec
def _expand_volume(v: str) -> str:
"""Expand bash-style volume strings including ${VAR:-default} and $VAR."""
def _sub(m: re.Match) -> str: # type: ignore[type-arg]
var, default = m.group(1), m.group(2) or ""
return os.environ.get(var) or default
v = re.sub(r"\$\{(\w+)(?::-(.*?))?\}", _sub, v)
v = re.sub(r"\$(\w+)", lambda m: os.environ.get(m.group(1), m.group(0)), v)
return v
class ServiceManager:
    """Starts and stops Docker containers and local processes for cf-orch managed services.

    Container naming convention: cf-orch-{service}-{node_id}.
    ProcessSpec services are tracked only in ``self._procs``, so their liveness
    is visible only to the ServiceManager instance that launched them.
    """

    def __init__(
        self,
        node_id: str,
        profile: GpuProfile,
        advertise_host: str = "127.0.0.1",
    ) -> None:
        self.node_id = node_id
        self.profile = profile
        self.advertise_host = advertise_host
        # service name -> subprocess.Popen handle for ProcessSpec services we launched.
        self._procs: dict[str, Any] = {}

    def container_name(self, service: str) -> str:
        """Return the Docker container name for *service* on this node."""
        return f"cf-orch-{service}-{self.node_id}"

    def _get_spec(self, service: str) -> DockerSpec | ProcessSpec | None:
        """Return the managed spec for *service*, or None if unknown or unmanaged."""
        svc = self.profile.services.get(service)
        if svc is None:
            return None
        return svc.managed

    def is_running(self, service: str) -> bool:
        """Best-effort liveness check.

        Docker: `docker inspect` container state. Process: the Popen must still
        be alive AND the service port must accept a TCP connection (the process
        can be up while the HTTP server inside is still loading a model).
        """
        spec = self._get_spec(service)
        if spec is None:
            return False
        if isinstance(spec, DockerSpec):
            try:
                result = subprocess.run(
                    [
                        "docker",
                        "inspect",
                        "--format",
                        "{{.State.Running}}",
                        self.container_name(service),
                    ],
                    capture_output=True,
                    text=True,
                    check=True,
                )
                return result.stdout.strip() == "true"
            except subprocess.CalledProcessError:
                # inspect exits non-zero when the container does not exist.
                return False
        if isinstance(spec, ProcessSpec):
            proc = self._procs.get(service)
            if proc is None or proc.poll() is not None:
                return False
            import socket

            try:
                with socket.create_connection(("127.0.0.1", spec.host_port), timeout=1):
                    return True
            except OSError:
                return False
        return False

    def start(self, service: str, gpu_id: int, params: dict[str, str]) -> str:
        """Start *service* on *gpu_id* and return its base URL.

        Idempotent: if the service is already running, returns its URL without
        restarting. Raises ValueError for unknown/unmanaged services and
        NotImplementedError for unrecognized spec types.
        """
        spec = self._get_spec(service)
        if spec is None:
            raise ValueError(f"Service {service!r} not in profile or has no managed spec")
        if self.is_running(service):
            return f"http://{self.advertise_host}:{spec.host_port}"
        if isinstance(spec, DockerSpec):
            expanded_volumes = [_expand_volume(v) for v in spec.volumes]
            # defaultdict(str): template params missing from *params* expand to ""
            # instead of raising KeyError in format_map.
            filler: dict[str, str] = defaultdict(str, params)
            expanded_command = spec.command_template.format_map(filler).split()
            cmd = [
                "docker", "run", "-d", "--rm",
                "--name", self.container_name(service),
                "--runtime", spec.runtime,
                "--gpus", f"device={gpu_id}",
                "--ipc", spec.ipc,
                "-p", f"{spec.host_port}:{spec.port}",
            ]
            for vol in expanded_volumes:
                cmd += ["-v", vol]
            for key, val in spec.env.items():
                cmd += ["-e", f"{key}={val}"]
            cmd.append(spec.image)
            cmd.extend(expanded_command)
            subprocess.run(cmd, check=True, capture_output=True, text=True)
            return f"http://{self.advertise_host}:{spec.host_port}"
        if isinstance(spec, ProcessSpec):
            filler = defaultdict(str, params)
            # Fill port/gpu_id from the spec/arguments unless caller overrode them.
            filler.setdefault("port", str(spec.port))
            filler.setdefault("gpu_id", str(gpu_id))
            args_expanded = spec.args_template.format_map(filler).split()
            cmd = [spec.exec_path] + args_expanded
            # Inherit the agent's environment (copy, so later mutation is isolated).
            env = os.environ.copy()
            proc = subprocess.Popen(
                cmd,
                cwd=spec.cwd or None,
                env=env,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            self._procs[service] = proc
            return f"http://{self.advertise_host}:{spec.host_port}"
        raise NotImplementedError(f"Unknown spec type: {type(spec)}")

    def stop(self, service: str) -> bool:
        """Stop *service*; returns True on success (idempotent for processes)."""
        spec = self._get_spec(service)
        if spec is None:
            return False
        if isinstance(spec, DockerSpec):
            try:
                subprocess.run(
                    ["docker", "stop", self.container_name(service)],
                    check=True,
                    capture_output=True,
                    text=True,
                )
                return True
            except subprocess.CalledProcessError:
                return False
        if isinstance(spec, ProcessSpec):
            proc = self._procs.pop(service, None)
            if proc is not None:
                proc.terminate()
                try:
                    proc.wait(timeout=10)
                except Exception:
                    # Did not exit within the grace period — force kill.
                    proc.kill()
            # No tracked proc still counts as a successful (idempotent) stop.
            return True
        return False

    def list_running(self) -> list[str]:
        """Return the names of profile services that are currently running."""
        return [svc for svc in self.profile.services if self.is_running(svc)]

    def get_url(self, service: str) -> str | None:
        """Return the base URL of *service* if it is running, else None."""
        spec = self._get_spec(service)
        if spec is None or not self.is_running(service):
            return None
        return f"http://{self.advertise_host}:{spec.host_port}"

View file

@ -0,0 +1,123 @@
"""
Probe running services to detect which models are currently loaded in VRAM.
Two probe strategies run together:
1. Well-known ports — always checked, regardless of who started the service.
Catches ollama, vLLM, etc. running outside cf-orch management.
2. Managed services — services cf-orch started via ServiceManager.
Checked on their configured host_port; deduplicated against the well-known results.
Each service exposes a different introspection API:
- vllm: GET /v1/models {"data": [{"id": "<model-name>"}]}
- ollama: GET /api/ps {"models": [{"name": "<model>", "size_vram": <bytes>}]}
ollama can have multiple models loaded simultaneously; each is reported as a
separate entry so the dashboard shows per-model residency.
The probe is best-effort: a timeout or connection refusal means model_name=None
but the service is still reported as resident.
"""
from __future__ import annotations
import json
import logging
import urllib.request
from typing import Any
from circuitforge_core.resources.profiles.schema import DockerSpec
logger = logging.getLogger(__name__)
_PROBE_TIMEOUT_S = 2.0
# Well-known service ports probed on every heartbeat.
# maps port → service name (also the key into _PROBERS)
_WELL_KNOWN_PORTS: dict[int, str] = {
11434: "ollama",
8000: "vllm",
8080: "vllm", # common alt vLLM port
}
def _fetch_json(url: str) -> dict[str, Any] | None:
    """GET *url* and decode its JSON body; any failure is logged at DEBUG and yields None."""
    try:
        with urllib.request.urlopen(url, timeout=_PROBE_TIMEOUT_S) as resp:
            body = resp.read()
        return json.loads(body)
    except Exception as exc:
        logger.debug("Probe %s: %s", url, exc)
        return None
def _probe_vllm(port: int) -> list[str]:
    """Return model ids served by a vLLM OpenAI-compatible endpoint on *port*."""
    payload = _fetch_json(f"http://127.0.0.1:{port}/v1/models")
    if not payload:
        return []
    entries = payload.get("data") or []
    return [entry["id"] for entry in entries if entry.get("id")]
def _probe_ollama(port: int) -> list[str]:
    """Return names of models an ollama instance on *port* has loaded in memory."""
    # /api/ps lists models currently *loaded in memory*, not just downloaded.
    payload = _fetch_json(f"http://127.0.0.1:{port}/api/ps")
    if not payload:
        return []
    entries = payload.get("models") or []
    return [entry["name"] for entry in entries if entry.get("name")]
_PROBERS: dict[str, Any] = {
"vllm": _probe_vllm,
"ollama": _probe_ollama,
}
def probe_all(service_manager: Any) -> list[dict[str, Any]]:
"""
Probe all services both well-known ports and cf-orch managed services.
Returns a list of dicts: [{"service": str, "model_name": str | None}].
Multiple loaded models in one service (e.g. two ollama models) each get
their own entry, disambiguated as "ollama/0", "ollama/1", etc.
"""
results: list[dict[str, Any]] = []
seen_ports: set[int] = set()
# ── 1. Well-known ports ──────────────────────────────────────────
for port, service in _WELL_KNOWN_PORTS.items():
prober = _PROBERS.get(service)
if prober is None:
continue
models = prober(port)
if not models:
continue # nothing on this port right now
seen_ports.add(port)
if len(models) == 1:
results.append({"service": service, "model_name": models[0]})
else:
for i, model in enumerate(models):
results.append({"service": f"{service}/{i}", "model_name": model})
# ── 2. Managed services (cf-orch started) ───────────────────────
if service_manager is not None:
for service in service_manager.list_running():
spec = service_manager._get_spec(service)
if not isinstance(spec, DockerSpec):
continue
if spec.host_port in seen_ports:
continue # already captured by well-known probe
prober = _PROBERS.get(service)
if prober is None:
results.append({"service": service, "model_name": None})
continue
models = prober(spec.host_port)
seen_ports.add(spec.host_port)
if not models:
results.append({"service": service, "model_name": None})
elif len(models) == 1:
results.append({"service": service, "model_name": models[0]})
else:
for i, model in enumerate(models):
results.append({"service": f"{service}/{i}", "model_name": model})
return results

View file

@ -1,9 +1,14 @@
from __future__ import annotations
import logging
import time
import urllib.request
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
@ -17,6 +22,54 @@ from circuitforge_core.resources.coordinator.service_registry import ServiceRegi
_DASHBOARD_HTML = (Path(__file__).parent / "dashboard.html").read_text()
_PROBE_INTERVAL_S = 5.0 # how often to poll starting instances
_PROBE_TIMEOUT_S = 300.0 # give up and mark stopped after this many seconds
async def _run_instance_probe_loop(service_registry: ServiceRegistry) -> None:
    """
    Background loop: transition 'starting' instances to 'running' once their
    /health endpoint responds, or to 'stopped' after PROBE_TIMEOUT_S.
    """
    import asyncio

    # instance key → time first seen as starting; entries are dropped as soon
    # as the instance leaves the 'starting' state, so the dict stays bounded.
    start_times: dict[str, float] = {}
    while True:
        await asyncio.sleep(_PROBE_INTERVAL_S)
        now = time.time()
        for inst in service_registry.all_instances():
            if inst.state != "starting":
                # Forget any stale start timestamp for non-starting instances.
                start_times.pop(f"{inst.service}:{inst.node_id}:{inst.gpu_id}", None)
                continue
            key = f"{inst.service}:{inst.node_id}:{inst.gpu_id}"
            # Record when we first observed this instance in 'starting'.
            start_times.setdefault(key, now)
            healthy = False
            if inst.url:
                # NOTE(review): urllib.request.urlopen is blocking; with a 2 s
                # timeout this can stall the event loop per probed instance —
                # consider asyncio.to_thread or an async HTTP client.
                try:
                    with urllib.request.urlopen(
                        inst.url.rstrip("/") + "/health", timeout=2.0
                    ) as resp:
                        healthy = resp.status == 200
                except Exception:
                    # Best-effort probe: any failure just means "not healthy yet".
                    pass
            if healthy:
                service_registry.upsert_instance(
                    service=inst.service, node_id=inst.node_id, gpu_id=inst.gpu_id,
                    state="running", model=inst.model, url=inst.url,
                )
                start_times.pop(key, None)
                logger.info("Instance %s/%s gpu=%s transitioned to running", inst.service, inst.node_id, inst.gpu_id)
            elif now - start_times[key] > _PROBE_TIMEOUT_S:
                # Gave up waiting: mark stopped so the scheduler can retry elsewhere.
                service_registry.upsert_instance(
                    service=inst.service, node_id=inst.node_id, gpu_id=inst.gpu_id,
                    state="stopped", model=inst.model, url=inst.url,
                )
                start_times.pop(key, None)
                logger.warning("Instance %s/%s gpu=%s timed out in starting state — marked stopped", inst.service, inst.node_id, inst.gpu_id)
class LeaseRequest(BaseModel):
node_id: str
@ -61,10 +114,12 @@ def create_coordinator_app(
@asynccontextmanager
async def _lifespan(app: FastAPI): # type: ignore[type-arg]
import asyncio
task = asyncio.create_task(agent_supervisor.run_heartbeat_loop())
heartbeat_task = asyncio.create_task(agent_supervisor.run_heartbeat_loop())
probe_task = asyncio.create_task(_run_instance_probe_loop(service_registry))
yield
agent_supervisor.stop()
task.cancel()
heartbeat_task.cancel()
probe_task.cancel()
app = FastAPI(title="cf-orch-coordinator", lifespan=_lifespan)
@ -227,12 +282,12 @@ def create_coordinator_app(
service_max_mb = svc.max_mb
break
# Filter candidates by VRAM headroom — skip models where free VRAM
# is less than half of the service's max_mb ceiling.
if service_max_mb > 0 and free_mb < service_max_mb // 2:
# Filter candidates by VRAM headroom — require free VRAM >= service ceiling
# so the model can actually load without competing for VRAM with other processes.
if service_max_mb > 0 and free_mb < service_max_mb:
raise HTTPException(
503,
detail=f"Insufficient VRAM on gpu {req.gpu_id}: {free_mb}MB free, need at least {service_max_mb // 2}MB",
detail=f"Insufficient VRAM on gpu {req.gpu_id}: {free_mb}MB free, need {service_max_mb}MB",
)
last_error: str = ""

View file

@ -42,7 +42,7 @@ def select_node(
for gpu in record.gpus:
warm = f"{node_id}:{service}" in resident_keys
effective = gpu.vram_free_mb + (_WARM_BONUS_MB if warm else 0)
can_fit = gpu.vram_free_mb >= service_max_mb // 2
can_fit = gpu.vram_free_mb >= service_max_mb
candidates.append(_Scored(
node_id=node_id,
gpu_id=gpu.gpu_id,

View file

@ -0,0 +1,137 @@
"""Generic OpenAI-compatible inference server for HuggingFace causal LMs."""
from __future__ import annotations
import argparse
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Module-level singletons populated once by _load_model() before the server
# starts accepting requests; read by the endpoint handlers below.
_model: Any = None
_tokenizer: Any = None
_model_id: str = ""
_device: str = "cpu"
@asynccontextmanager
async def lifespan(app: FastAPI):
    # No startup/shutdown work here: the model is loaded in main() before
    # uvicorn.run() is called, so the lifespan is a no-op placeholder.
    yield
app = FastAPI(lifespan=lifespan)
class Message(BaseModel):
    """One chat turn in the OpenAI chat-completions wire format."""

    role: str
    content: str
class ChatRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI-compatible subset)."""

    model: str | None = None  # accepted but ignored: this server hosts one model
    messages: list[Message]
    max_tokens: int | None = 512
    temperature: float | None = 0.7
    stream: bool | None = False  # streaming requests are rejected with 501
@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe; also reports which model this server was started with."""
    return {"status": "ok", "model": _model_id}
@app.get("/v1/models")
def list_models() -> dict[str, Any]:
    """OpenAI-compatible model listing; exactly one model is ever served."""
    return {
        "object": "list",
        "data": [{"id": _model_id, "object": "model", "owned_by": "cf-orch"}],
    }
@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest) -> dict[str, Any]:
    """Minimal OpenAI-compatible chat completion over the single loaded model.

    Returns 503 until _load_model() has run, 501 for streaming requests, and
    500 when the tokenizer rejects the conversation.
    """
    if _model is None:
        raise HTTPException(503, detail="Model not loaded")
    if req.stream:
        raise HTTPException(501, detail="Streaming not supported")
    conversation = [{"role": m.role, "content": m.content} for m in req.messages]
    try:
        encoded = _tokenizer.apply_chat_template(
            conversation,
            return_tensors="pt",
            add_generation_prompt=True,
        )
        # transformers 5.x returns BatchEncoding; 4.x returned a bare tensor
        input_ids = (encoded.input_ids if hasattr(encoded, "input_ids") else encoded).to(_device)
    except Exception as exc:
        # Chain the cause so the original tokenizer error survives in tracebacks.
        raise HTTPException(500, detail=f"Tokenisation failed: {exc}") from exc
    max_new = req.max_tokens or 512
    temp = req.temperature if req.temperature is not None else 0.7
    gen_kwargs: dict[str, Any] = {
        "max_new_tokens": max_new,
        # temperature == 0 means greedy decoding; sample only when temp > 0.
        "do_sample": temp > 0,
        "pad_token_id": _tokenizer.eos_token_id,
    }
    if temp > 0:
        gen_kwargs["temperature"] = temp
    with torch.inference_mode():
        output_ids = _model.generate(input_ids, **gen_kwargs)
    # Strip the prompt prefix; decode only what the model generated.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    reply = _tokenizer.decode(new_tokens, skip_special_tokens=True)
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": _model_id,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": reply},
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": input_ids.shape[-1],
            "completion_tokens": len(new_tokens),
            "total_tokens": input_ids.shape[-1] + len(new_tokens),
        },
    }
def _load_model(model_path: str, gpu_id: int) -> None:
    """Load tokenizer + causal LM onto cuda:{gpu_id} (fp16) or CPU (fp32).

    Populates the module-level _model/_tokenizer/_model_id/_device globals
    that the endpoint handlers read.
    """
    global _model, _tokenizer, _model_id, _device
    _device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    _model_id = model_path
    # NOTE(review): trust_remote_code executes code from the model directory —
    # acceptable only for vetted local model paths.
    _tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    _model = AutoModelForCausalLM.from_pretrained(
        model_path,
        # fp16 on GPU for VRAM headroom; fp32 on CPU where fp16 is slow/unsupported.
        dtype=torch.float16 if "cuda" in _device else torch.float32,
        device_map={"": _device},
        trust_remote_code=True,
    )
    _model.eval()
def main() -> None:
    """CLI entry point: parse args, load the model, then serve with uvicorn."""
    parser = argparse.ArgumentParser(description="cf-orch generic LLM inference server")
    parser.add_argument("--model", required=True)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--gpu-id", type=int, default=0)
    args = parser.parse_args()
    # Load before binding the port so the server only answers once the model is up.
    _load_model(args.model, args.gpu_id)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import time
import uuid
from dataclasses import dataclass, field
from typing import Optional
@dataclass(frozen=True)
@ -48,6 +49,15 @@ class GpuInfo:
vram_free_mb: int
@dataclass(frozen=True)
class ResidentAllocation:
"""A model that is loaded and warm in VRAM but not actively serving a request."""
service: str
node_id: str
model_name: Optional[str] # None if service is running but model probe failed
first_seen: float = field(default_factory=time.time)
@dataclass
class NodeInfo:
node_id: str

View file

@ -4,9 +4,16 @@ vram_total_mb: 16384
eviction_timeout_s: 10.0
services:
vllm:
max_mb: 12288
max_mb: 9000
priority: 1
idle_stop_after_s: 600
managed:
type: process
exec_path: "/devl/miniconda3/envs/cf/bin/python"
args_template: "-m circuitforge_core.resources.inference.llm_server --model /Library/Assets/LLM/vllm/models/{model} --port {port} --gpu-id {gpu_id}"
port: 8000
host_port: 8000
cwd: "/Library/Development/CircuitForge/circuitforge-core"
ollama:
max_mb: 12288
priority: 1

View file

@ -4,9 +4,16 @@ vram_total_mb: 24576
eviction_timeout_s: 10.0
services:
vllm:
max_mb: 20480
max_mb: 9000
priority: 1
idle_stop_after_s: 600
managed:
type: process
exec_path: "/devl/miniconda3/envs/cf/bin/python"
args_template: "-m circuitforge_core.resources.inference.llm_server --model /Library/Assets/LLM/vllm/models/{model} --port {port} --gpu-id {gpu_id}"
port: 8000
host_port: 8000
cwd: "/Library/Development/CircuitForge/circuitforge-core"
ollama:
max_mb: 18432
priority: 1

View file

@ -4,19 +4,16 @@ vram_total_mb: 6144
eviction_timeout_s: 10.0
services:
vllm:
max_mb: 4096
max_mb: 5500
priority: 1
idle_stop_after_s: 600
managed:
type: docker
image: "vllm/vllm-openai:v0.9.2"
type: process
exec_path: "/devl/miniconda3/envs/cf/bin/python"
args_template: "-m circuitforge_core.resources.inference.llm_server --model /Library/Assets/LLM/vllm/models/{model} --port {port} --gpu-id {gpu_id}"
port: 8000
host_port: 8000
command_template: "--model /models/{model} --trust-remote-code --max-model-len {max_model_len} --gpu-memory-utilization {gpu_mem_util} --enforce-eager --max-num-seqs 8"
volumes:
- "${VLLM_MODELS_DIR:-/Library/Assets/LLM/vllm/models}:/models"
runtime: nvidia
ipc: host
cwd: "/Library/Development/CircuitForge/circuitforge-core"
ollama:
max_mb: 3584
priority: 1

View file

@ -4,19 +4,16 @@ vram_total_mb: 8192
eviction_timeout_s: 10.0
services:
vllm:
max_mb: 5120
max_mb: 6500
priority: 1
idle_stop_after_s: 600
managed:
type: docker
image: "vllm/vllm-openai:v0.9.2"
type: process
exec_path: "/devl/miniconda3/envs/cf/bin/python"
args_template: "-m circuitforge_core.resources.inference.llm_server --model /Library/Assets/LLM/vllm/models/{model} --port {port} --gpu-id {gpu_id}"
port: 8000
host_port: 8000
command_template: "--model /models/{model} --trust-remote-code --max-model-len {max_model_len} --gpu-memory-utilization {gpu_mem_util} --enforce-eager --max-num-seqs 8"
volumes:
- "${VLLM_MODELS_DIR:-/Library/Assets/LLM/vllm/models}:/models"
runtime: nvidia
ipc: host
cwd: "/Library/Development/CircuitForge/circuitforge-core"
ollama:
max_mb: 4096
priority: 1

View file

@ -149,8 +149,8 @@ def test_single_gpu_8gb_profile_has_idle_stop_after_s():
def test_ensure_service_returns_503_when_vram_too_low():
"""VRAM pre-flight guard fires before any HTTP request when free VRAM < max_mb // 2."""
# vllm max_mb = 5120 → threshold = 2560 MB; 100 MB free triggers 503.
"""VRAM pre-flight guard fires before any HTTP request when free VRAM < service max_mb."""
# Threshold = full max_mb (not half); 100 MB free on any profile triggers 503.
lease_manager = LeaseManager()
lease_manager.register_gpu("low-vram-node", 0, 512)
profile_registry = ProfileRegistry()

View file

@ -54,3 +54,29 @@ def test_returns_none_when_no_agents():
registry = ProfileRegistry()
result = select_node({}, "vllm", registry, resident_keys=set())
assert result is None
def test_prefers_node_that_fully_fits_service_over_one_that_does_not():
    """can_fit requires free_mb >= service max_mb (full ceiling, not half).

    9500 MB guarantees above all profile ceilings (max is 9000); 1000 MB is below all.
    """
    fleet = {
        "a": _make_agent("a", free_mb=1000),
        "b": _make_agent("b", free_mb=9500),
    }
    profile_registry = ProfileRegistry()
    picked = select_node(fleet, "vllm", profile_registry, resident_keys=set())
    # Only "b" lands in the preferred (can_fit) pool, so it must win.
    assert picked == ("b", 0)
def test_falls_back_to_best_effort_when_no_node_fully_fits():
    """When nothing can_fit, select_node returns the best-VRAM node as fallback."""
    fleet = {
        "a": _make_agent("a", free_mb=1000),
        "b": _make_agent("b", free_mb=2000),
    }
    profile_registry = ProfileRegistry()
    # Neither node has enough free VRAM; the fallback picks the highest
    # effective_free_mb, which is node "b".
    picked = select_node(fleet, "vllm", profile_registry, resident_keys=set())
    assert picked == ("b", 0)

View file

@ -0,0 +1,194 @@
"""Tests for ServiceManager ProcessSpec support."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from circuitforge_core.resources.agent.service_manager import ServiceManager
from circuitforge_core.resources.profiles.schema import (
GpuProfile,
ProcessSpec,
ServiceProfile,
)
def _make_profile(args_template: str = "--port {port} --gpu-id {gpu_id}") -> GpuProfile:
    """Build a minimal profile: one ProcessSpec-managed service plus one unmanaged."""
    return GpuProfile(
        schema_version=1,
        name="test",
        vram_total_mb=8192,
        services={
            "vllm": ServiceProfile(
                max_mb=5120,
                priority=1,
                managed=ProcessSpec(
                    exec_path="/usr/bin/python",
                    args_template=args_template,
                    port=8000,
                    host_port=8000,
                    cwd="/tmp",
                ),
            ),
            # Service with no managed spec: exercises the "unmanaged" code paths.
            "no_managed": ServiceProfile(max_mb=1024, priority=2),
        },
    )
@pytest.fixture
def manager():
    """Fresh ServiceManager over the canned test profile for each test."""
    return ServiceManager(node_id="test-node", profile=_make_profile(), advertise_host="127.0.0.1")
# ---------------------------------------------------------------------------
# is_running
# ---------------------------------------------------------------------------
def test_is_running_returns_false_when_no_proc(manager):
    # No Popen handle was ever tracked for the service → not running.
    assert manager.is_running("vllm") is False
def test_is_running_returns_false_when_proc_exited(manager):
    # A tracked proc whose poll() returns an exit code counts as dead.
    mock_proc = MagicMock()
    mock_proc.poll.return_value = 1  # exited
    manager._procs["vllm"] = mock_proc
    assert manager.is_running("vllm") is False
def test_is_running_returns_false_when_port_not_listening(manager):
    # Process alive but its port refuses connections → treated as not running
    # (e.g. the server inside is still loading a model).
    mock_proc = MagicMock()
    mock_proc.poll.return_value = None  # still running
    manager._procs["vllm"] = mock_proc
    with patch("socket.create_connection", side_effect=OSError("refused")):
        assert manager.is_running("vllm") is False
def test_is_running_returns_true_when_proc_alive_and_port_open(manager):
    mock_proc = MagicMock()
    mock_proc.poll.return_value = None  # still running
    manager._procs["vllm"] = mock_proc
    # create_connection is used as a context manager, so the mock needs
    # __enter__/__exit__ wired up explicitly.
    mock_socket = MagicMock()
    mock_socket.__enter__ = MagicMock(return_value=mock_socket)
    mock_socket.__exit__ = MagicMock(return_value=False)
    with patch("socket.create_connection", return_value=mock_socket):
        assert manager.is_running("vllm") is True
def test_is_running_unknown_service_returns_false(manager):
    # Services absent from the profile are never running.
    assert manager.is_running("nonexistent") is False
def test_is_running_no_managed_spec_returns_false(manager):
    # A profile entry without a managed spec cannot be started, hence never running.
    assert manager.is_running("no_managed") is False
# ---------------------------------------------------------------------------
# start
# ---------------------------------------------------------------------------
def test_start_launches_process_and_returns_url(manager):
    with patch("subprocess.Popen") as mock_popen, \
            patch.object(manager, "is_running", return_value=False):
        mock_popen.return_value = MagicMock()
        url = manager.start("vllm", gpu_id=0, params={"model": "mymodel"})
        assert url == "http://127.0.0.1:8000"
        mock_popen.assert_called_once()
        # First positional arg of Popen is the command list built from the template.
        call_args = mock_popen.call_args
        cmd = call_args[0][0]
        assert cmd[0] == "/usr/bin/python"
        assert "--port" in cmd
        assert "8000" in cmd
        assert "--gpu-id" in cmd
        assert "0" in cmd
def test_start_returns_url_immediately_when_already_running(manager):
    # Idempotent start: no new process is spawned when one is already running.
    with patch.object(manager, "is_running", return_value=True):
        with patch("subprocess.Popen") as mock_popen:
            url = manager.start("vllm", gpu_id=0, params={})
            assert url == "http://127.0.0.1:8000"
            mock_popen.assert_not_called()
def test_start_raises_for_unknown_service(manager):
    # Unknown services fail fast with ValueError before any subprocess work.
    with pytest.raises(ValueError, match="not in profile"):
        manager.start("nonexistent", gpu_id=0, params={})
def test_start_stores_proc_in_procs(manager):
    # The Popen handle must be tracked so stop()/is_running() can find it later.
    mock_proc = MagicMock()
    with patch("subprocess.Popen", return_value=mock_proc), \
            patch.object(manager, "is_running", return_value=False):
        manager.start("vllm", gpu_id=0, params={})
    assert manager._procs["vllm"] is mock_proc
# ---------------------------------------------------------------------------
# stop
# ---------------------------------------------------------------------------
def test_stop_terminates_running_process(manager):
    # Graceful path: terminate(), then wait(); the handle is untracked afterwards.
    mock_proc = MagicMock()
    manager._procs["vllm"] = mock_proc
    result = manager.stop("vllm")
    assert result is True
    mock_proc.terminate.assert_called_once()
    mock_proc.wait.assert_called_once()
    assert "vllm" not in manager._procs
def test_stop_kills_process_that_wont_terminate(manager):
    # If wait() raises (e.g. timeout), stop() escalates to kill() and still succeeds.
    mock_proc = MagicMock()
    mock_proc.wait.side_effect = Exception("timeout")
    manager._procs["vllm"] = mock_proc
    result = manager.stop("vllm")
    assert result is True
    mock_proc.kill.assert_called_once()
def test_stop_returns_true_when_no_proc_tracked(manager):
    # No proc in _procs — still returns True (idempotent stop)
    result = manager.stop("vllm")
    assert result is True
def test_stop_returns_false_for_unknown_service(manager):
    # Unlike a known-but-idle service, an unknown service reports failure.
    result = manager.stop("nonexistent")
    assert result is False
# ---------------------------------------------------------------------------
# list_running / get_url
# ---------------------------------------------------------------------------
def test_list_running_returns_running_services(manager):
    # Only services whose is_running() is True appear in the listing.
    def _is_running(svc: str) -> bool:
        return svc == "vllm"

    with patch.object(manager, "is_running", side_effect=_is_running):
        running = manager.list_running()
        assert running == ["vllm"]
def test_get_url_returns_none_when_not_running(manager):
    # A stopped service has no advertised URL.
    with patch.object(manager, "is_running", return_value=False):
        assert manager.get_url("vllm") is None
def test_get_url_returns_url_when_running(manager):
    # URL is advertise_host + the spec's host_port.
    with patch.object(manager, "is_running", return_value=True):
        assert manager.get_url("vllm") == "http://127.0.0.1:8000"