diff --git a/tests/test_nodes.py b/tests/test_nodes.py index a475286..858724a 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -198,3 +198,150 @@ def test_get_profile_500_on_malformed_yaml(client, tmp_path): r = client.get("/api/nodes-mgmt/nodes/bad/profile") assert r.status_code == 500 + + +# ── POST /api/nodes-mgmt/nodes/{node_id}/gpu/{gpu_id}/services ───────────────── + +import os as _os + +_BASE_PROFILE = { + "services": { + "cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "priority": 1, + "catalog": {"llama3": {"vram_mb": 6144, "path": "/m/llama3", + "description": "", "multi_gpu": False, "env": {}}}}, + "ollama": {"min_compute_cap": 0.0, "max_mb": 2048, "priority": 2, "catalog": {}}, + }, + "nodes": { + "heimdall": { + "gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 8.6, + "services": [], "role": "primary", "card": "RTX 3090", + "always_on": True}], + "agent_url": "http://10.1.10.71:7701", + } + } +} + + +def _setup_profile(tmp_path, profile=None): + profiles_dir = tmp_path / "profiles" + _write_config(tmp_path, { + "coordinator_url": "http://fake-coord:7700", + "profiles_dir": str(profiles_dir), + }) + _write_profile(profiles_dir, "heimdall", profile or _BASE_PROFILE) + return profiles_dir + + +def test_update_services_compatible_writes_and_reloads(client, tmp_path): + profiles_dir = _setup_profile(tmp_path) + + mock_reload = MagicMock() + mock_reload.status_code = 200 + + with patch("httpx.post", return_value=mock_reload): + r = client.post( + "/api/nodes-mgmt/nodes/heimdall/gpu/0/services", + json={"services": ["cf-text"]}, + ) + + assert r.status_code == 200 + data = r.json() + assert data["ok"] is True + assert data["reloaded"] is True + + saved = yaml.safe_load((profiles_dir / "heimdall.yaml").read_text()) + assert saved["nodes"]["heimdall"]["gpus"][0]["services"] == ["cf-text"] + + +def test_update_services_atomic_write_uses_tmp_file(client, tmp_path): + """YAML must be written to .tmp then renamed — never written directly.""" + profiles_dir = _setup_profile(tmp_path) + renamed_pairs: list[tuple] = [] + + original_replace = _os.replace + + def capture(src, dst): + renamed_pairs.append((str(src), str(dst))) + original_replace(src, dst) + + with patch("os.replace", side_effect=capture), \ + patch("httpx.post", return_value=MagicMock(status_code=200)): + client.post( + "/api/nodes-mgmt/nodes/heimdall/gpu/0/services", + json={"services": ["ollama"]}, + ) + + assert any(src.endswith(".tmp") for src, dst in renamed_pairs), \ + "Expected atomic write via .tmp rename" + + +def test_update_services_incompatible_compute_cap_returns_422(client, tmp_path): + low_cap_profile = { + **_BASE_PROFILE, + "nodes": { + "heimdall": { + "gpus": [{"id": 0, "vram_mb": 24576, "compute_cap": 6.0, + "services": [], "role": "p", "card": "GTX 1080", + "always_on": False}], + "agent_url": "http://10.1.10.71:7701", + } + } + } + _setup_profile(tmp_path, low_cap_profile) + + r = client.post( + "/api/nodes-mgmt/nodes/heimdall/gpu/0/services", + json={"services": ["cf-text"]}, + ) + assert r.status_code == 422 + assert "compute_cap" in r.json()["detail"] + + +def test_update_services_insufficient_vram_returns_422(client, tmp_path): + tiny_vram_profile = { + **_BASE_PROFILE, + "nodes": { + "heimdall": { + "gpus": [{"id": 0, "vram_mb": 512, "compute_cap": 8.6, + "services": [], "role": "p", "card": "old", + "always_on": False}], + "agent_url": "http://10.1.10.71:7701", + } + } + } + _setup_profile(tmp_path, tiny_vram_profile) + + r = client.post( + "/api/nodes-mgmt/nodes/heimdall/gpu/0/services", + json={"services": ["cf-text"]}, + ) + assert r.status_code == 422 + assert "VRAM" in r.json()["detail"] + + +def test_update_services_unknown_service_returns_422(client, tmp_path): + _setup_profile(tmp_path) + r = client.post( + "/api/nodes-mgmt/nodes/heimdall/gpu/0/services", + json={"services": ["not-a-real-service"]}, + ) + assert r.status_code == 422 + + +def test_update_services_reload_failure_returns_reloaded_false(client, tmp_path): + """YAML saved but coordinator reload fails — ok: true, reloaded: false.""" + _setup_profile(tmp_path) + + mock_reload = MagicMock() + mock_reload.status_code = 500 + + with patch("httpx.post", return_value=mock_reload): + r = client.post( + "/api/nodes-mgmt/nodes/heimdall/gpu/0/services", + json={"services": ["ollama"]}, + ) + + assert r.status_code == 200 + data = r.json() + assert data["ok"] is True + assert data["reloaded"] is False