feat: add profile endpoint and GPU service assignment with compatibility check

This commit is contained in:
pyr0ball 2026-05-05 20:33:41 -07:00
parent fd8cb622a1
commit f952ec8971

View file

@ -198,3 +198,150 @@ def test_get_profile_500_on_malformed_yaml(client, tmp_path):
r = client.get("/api/nodes-mgmt/nodes/bad/profile") r = client.get("/api/nodes-mgmt/nodes/bad/profile")
assert r.status_code == 500 assert r.status_code == 500
# ── POST /api/nodes-mgmt/nodes/{node_id}/gpu/{gpu_id}/services ─────────────────
import os as _os
# Baseline profile document shared by the GPU service-assignment tests:
# two service definitions plus a single node ("heimdall") with one GPU.
_BASE_PROFILE = {
    "services": {
        # Service with a non-zero compute-capability floor and one catalog model.
        "cf-text": {
            "min_compute_cap": 7.0,
            "max_mb": 8192,
            "priority": 1,
            "catalog": {
                "llama3": {
                    "vram_mb": 6144,
                    "path": "/m/llama3",
                    "description": "",
                    "multi_gpu": False,
                    "env": {},
                },
            },
        },
        # Service with no compute-capability floor and an empty catalog.
        "ollama": {
            "min_compute_cap": 0.0,
            "max_mb": 2048,
            "priority": 2,
            "catalog": {},
        },
    },
    "nodes": {
        "heimdall": {
            "gpus": [
                {
                    "id": 0,
                    "vram_mb": 24576,
                    "compute_cap": 8.6,
                    "services": [],
                    "role": "primary",
                    "card": "RTX 3090",
                    "always_on": True,
                },
            ],
            "agent_url": "http://10.1.10.71:7701",
        },
    },
}
def _setup_profile(tmp_path, profile=None):
    """Write a test config and a "heimdall" profile under *tmp_path*.

    Args:
        tmp_path: Base directory for the test (pytest's ``tmp_path``).
        profile: Profile document to write for the "heimdall" node. When
            ``None`` the shared ``_BASE_PROFILE`` is used.

    Returns:
        The ``profiles`` directory (``tmp_path / "profiles"``) that the
        config points at.
    """
    profiles_dir = tmp_path / "profiles"
    _write_config(tmp_path, {
        "coordinator_url": "http://fake-coord:7700",
        "profiles_dir": str(profiles_dir),
    })
    # Compare against None rather than relying on truthiness: a caller that
    # deliberately passes an empty (falsy) profile must get exactly that
    # profile written, not a silent fallback to the base profile.
    chosen_profile = _BASE_PROFILE if profile is None else profile
    _write_profile(profiles_dir, "heimdall", chosen_profile)
    return profiles_dir
def test_update_services_compatible_writes_and_reloads(client, tmp_path):
    """Assigning cf-text to the base profile's GPU persists it and reports a reload."""
    profiles_dir = _setup_profile(tmp_path)
    reload_ok = MagicMock(status_code=200)

    with patch("httpx.post", return_value=reload_ok):
        response = client.post(
            "/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
            json={"services": ["cf-text"]},
        )

    assert response.status_code == 200
    body = response.json()
    assert body["ok"] is True
    assert body["reloaded"] is True

    # The assignment must have been persisted to the node's YAML profile.
    on_disk = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
    first_gpu = on_disk["nodes"]["heimdall"]["gpus"][0]
    assert first_gpu["services"] == ["cf-text"]
def test_update_services_atomic_write_uses_tmp_file(client, tmp_path):
    """YAML must be written to .tmp then renamed — never written directly.

    ``os.replace`` is wrapped so every (src, dst) pair is recorded while the
    real rename still happens. The test then checks that the request
    succeeded, that at least one rename came from a ``.tmp`` source, and that
    the profile on disk holds the new assignment.
    """
    profiles_dir = _setup_profile(tmp_path)
    renamed_pairs: list[tuple[str, str]] = []
    # Grab the real function before patching so the wrapper can delegate to it.
    original_replace = _os.replace

    def capture(src, dst):
        renamed_pairs.append((str(src), str(dst)))
        return original_replace(src, dst)

    with patch("os.replace", side_effect=capture), \
            patch("httpx.post", return_value=MagicMock(status_code=200)):
        r = client.post(
            "/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
            json={"services": ["ollama"]},
        )

    # Without this, a rejected request would only show up as a confusing
    # "no rename recorded" failure below.
    assert r.status_code == 200
    assert any(src.endswith(".tmp") for src, _dst in renamed_pairs), \
        "Expected atomic write via .tmp rename"
    # The rename must leave the updated profile in place, not just happen.
    saved = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
    assert saved["nodes"]["heimdall"]["gpus"][0]["services"] == ["ollama"]
def test_update_services_incompatible_compute_cap_returns_422(client, tmp_path):
    """A GPU with compute_cap 6.0 cannot take cf-text (min_compute_cap 7.0): expect 422."""
    weak_gpu = {
        "id": 0,
        "vram_mb": 24576,
        "compute_cap": 6.0,
        "services": [],
        "role": "p",
        "card": "GTX 1080",
        "always_on": False,
    }
    # Same services as the base profile; only the node definition is swapped.
    low_cap_profile = dict(
        _BASE_PROFILE,
        nodes={
            "heimdall": {
                "gpus": [weak_gpu],
                "agent_url": "http://10.1.10.71:7701",
            },
        },
    )
    _setup_profile(tmp_path, low_cap_profile)

    response = client.post(
        "/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
        json={"services": ["cf-text"]},
    )

    assert response.status_code == 422
    assert "compute_cap" in response.json()["detail"]
def test_update_services_insufficient_vram_returns_422(client, tmp_path):
    """A 512 MB GPU is rejected for cf-text with a 422 whose detail mentions VRAM."""
    small_gpu = {
        "id": 0,
        "vram_mb": 512,
        "compute_cap": 8.6,
        "services": [],
        "role": "p",
        "card": "old",
        "always_on": False,
    }
    # Same services as the base profile; only the node definition is swapped.
    tiny_vram_profile = dict(
        _BASE_PROFILE,
        nodes={
            "heimdall": {
                "gpus": [small_gpu],
                "agent_url": "http://10.1.10.71:7701",
            },
        },
    )
    _setup_profile(tmp_path, tiny_vram_profile)

    response = client.post(
        "/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
        json={"services": ["cf-text"]},
    )

    assert response.status_code == 422
    assert "VRAM" in response.json()["detail"]
def test_update_services_unknown_service_returns_422(client, tmp_path):
    """Requesting a service name absent from the profile's services yields 422."""
    _setup_profile(tmp_path)
    payload = {"services": ["not-a-real-service"]}

    response = client.post(
        "/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
        json=payload,
    )

    assert response.status_code == 422
def test_update_services_reload_failure_returns_reloaded_false(client, tmp_path):
    """YAML saved but coordinator reload fails — ok: true, reloaded: false.

    The mocked ``httpx.post`` answers with status 500. The endpoint is
    expected to still return 200 with ``ok`` true and ``reloaded`` false,
    and the new assignment must already be in the profile on disk.
    """
    profiles_dir = _setup_profile(tmp_path)
    mock_reload = MagicMock()
    mock_reload.status_code = 500

    with patch("httpx.post", return_value=mock_reload):
        r = client.post(
            "/api/nodes-mgmt/nodes/heimdall/gpu/0/services",
            json={"services": ["ollama"]},
        )

    assert r.status_code == 200
    data = r.json()
    assert data["ok"] is True
    assert data["reloaded"] is False

    # Verify the "YAML saved" half of the contract, not just the response body.
    saved = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text())
    assert saved["nodes"]["heimdall"]["gpus"][0]["services"] == ["ollama"]