diff --git a/app/nodes.py b/app/nodes.py
index 891f56d..04f557a 100644
--- a/app/nodes.py
+++ b/app/nodes.py
@@ -11,11 +11,13 @@ Module-level globals follow the set_config_dir() testability pattern from cforch
 from __future__ import annotations
 
 import logging
+import os
 from pathlib import Path
 from urllib.parse import urlparse
 
 import yaml
 from fastapi import APIRouter, HTTPException
+from pydantic import BaseModel
 
 logger = logging.getLogger(__name__)
 
@@ -187,3 +189,173 @@ def list_nodes() -> list:
         })
 
     return result
+
+
+def _read_profile(node_id: str) -> tuple[Path, dict]:
+    """Locate and parse the profile YAML for *node_id*.
+
+    Shared by the profile endpoints so that missing and malformed profiles are
+    reported the same way everywhere.
+
+    Returns:
+        The profile path and its parsed top-level mapping ({} for an empty file).
+
+    Raises:
+        HTTPException: 404 when no profile file exists for the node; 500 when
+            the file is not valid YAML or its top level is not a mapping.
+    """
+    path = _profile_path(node_id)
+    if path is None or not path.exists():
+        raise HTTPException(404, f"No profile found for node {node_id}")
+    try:
+        data = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
+    except yaml.YAMLError as exc:
+        raise HTTPException(500, f"Malformed profile YAML: {exc}") from exc
+    if not isinstance(data, dict):
+        raise HTTPException(500, "Malformed profile YAML: top level must be a mapping")
+    return path, data
+
+
+def _required_vram_mb(svc: dict) -> int:
+    """Return the least VRAM (MB) a GPU needs in order to host *svc*.
+
+    With a non-empty ``catalog`` this is the smallest ``vram_mb`` among its
+    entries; otherwise it is the service-level ``max_mb``. Missing or null
+    values count as 0.
+    """
+    catalog = svc.get("catalog") or {}
+    if catalog:
+        return min((entry or {}).get("vram_mb") or 0 for entry in catalog.values())
+    return svc.get("max_mb") or 0
+
+
+@router.get("/nodes/{node_id}/profile")
+def get_node_profile(node_id: str) -> dict:
+    """Return the full parsed profile YAML for a node.
+
+    Raises:
+        HTTPException: 404 if the node has no profile; 500 if it cannot be parsed.
+    """
+    _, data = _read_profile(node_id)
+    return data
+
+
+class UpdateServicesRequest(BaseModel):
+    """Request body for update_gpu_services()."""
+
+    # Names of the services to assign to the GPU; replaces the existing list.
+    services: list[str]
+
+
+@router.post("/nodes/{node_id}/gpu/{gpu_id}/services")
+def update_gpu_services(node_id: str, gpu_id: int, body: UpdateServicesRequest) -> dict:
+    """Set the service assignment for one GPU after compatibility validation.
+
+    Every requested service must be defined in the profile's ``services``
+    mapping and must fit the GPU's ``compute_cap`` and ``vram_mb``. The updated
+    profile is written to a sibling ``.tmp`` file and moved over the original
+    with ``os.replace``. If ``coordinator_url`` is configured, the coordinator
+    is then asked to reload the profile; a failed reload does not fail the
+    request but is reported through ``reloaded`` and ``warnings``.
+
+    Returns:
+        ``{"ok": True, "reloaded": bool, "warnings": list[str]}``.
+
+    Raises:
+        HTTPException: 404 for an unknown node or GPU; 422 for an undefined or
+            incompatible service; 500 if the profile cannot be parsed or saved.
+    """
+    import httpx
+
+    cfg = _load_config()
+    # Strip a trailing slash so the reload URL never contains "//api".
+    coordinator_url = str(cfg.get("coordinator_url") or "").rstrip("/")
+
+    p, profile = _read_profile(node_id)
+
+    nodes_section = profile.get("nodes") or {}
+    node_entry = nodes_section.get(node_id) or {}
+    gpu_list = node_entry.get("gpus") or []
+
+    gpu_entry = next(
+        (g for g in gpu_list if isinstance(g, dict) and g.get("id") == gpu_id),
+        None,
+    )
+    if gpu_entry is None:
+        raise HTTPException(404, f"GPU {gpu_id} not found in profile for node {node_id}")
+
+    gpu_compute_cap: float = gpu_entry.get("compute_cap") or 0.0
+    gpu_vram_mb: int = gpu_entry.get("vram_mb") or 0
+    services_def = profile.get("services") or {}
+
+    for svc_name in body.services:
+        if svc_name not in services_def:
+            raise HTTPException(422, f"Service '{svc_name}' not defined in profile services dict")
+        # A bare "name:" entry parses as None; treat it as having no constraints.
+        svc = services_def[svc_name] or {}
+        min_cap: float = svc.get("min_compute_cap") or 0.0
+        if gpu_compute_cap < min_cap:
+            raise HTTPException(
+                422,
+                f"Service '{svc_name}' requires compute_cap >= {min_cap}; GPU has {gpu_compute_cap}",
+            )
+        required_vram_mb = _required_vram_mb(svc)
+        if gpu_vram_mb < required_vram_mb:
+            raise HTTPException(
+                422,
+                f"Service '{svc_name}' requires {required_vram_mb} MB VRAM; GPU has {gpu_vram_mb} MB",
+            )
+
+    # Build updated copies rather than mutating the parsed profile in place.
+    new_gpu_list = [
+        ({**g, "services": list(body.services)} if isinstance(g, dict) and g.get("id") == gpu_id else g)
+        for g in gpu_list
+    ]
+    new_profile = {
+        **profile,
+        "nodes": {
+            **nodes_section,
+            node_id: {**node_entry, "gpus": new_gpu_list},
+        },
+    }
+
+    # Write to a sibling .tmp file, then rename over the original so readers
+    # never observe a partially written profile.
+    tmp_yaml = Path(str(p) + ".tmp")
+    try:
+        tmp_yaml.write_text(
+            # sort_keys=False keeps the profile's existing key order on disk.
+            yaml.safe_dump(new_profile, default_flow_style=False, sort_keys=False),
+            encoding="utf-8",
+        )
+        os.replace(tmp_yaml, p)
+    except OSError as exc:
+        try:
+            tmp_yaml.unlink(missing_ok=True)  # do not leave a stale temp file behind
+        except OSError:
+            logger.warning("Could not remove temporary profile file %s", tmp_yaml)
+        raise HTTPException(500, f"Failed to write profile for node {node_id}: {exc}") from exc
+
+    # Ask the coordinator to reload. Best-effort: the profile is already saved,
+    # so failures are logged and surfaced as warnings instead of raised.
+    reloaded = False
+    warnings: list[str] = []
+    if coordinator_url:
+        try:
+            rr = httpx.post(
+                f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0
+            )
+        except Exception as exc:
+            logger.warning("Coordinator profile reload failed for node %s: %s", node_id, exc)
+            warnings.append(f"Coordinator profile reload failed: {exc}")
+        else:
+            reloaded = rr.status_code < 300
+            if not reloaded:
+                logger.warning(
+                    "Coordinator profile reload for node %s returned HTTP %s",
+                    node_id,
+                    rr.status_code,
+                )
+                warnings.append(f"Coordinator profile reload returned HTTP {rr.status_code}")
+
+    return {"ok": True, "reloaded": reloaded, "warnings": warnings}
diff --git a/tests/test_nodes.py b/tests/test_nodes.py
index 52df098..a475286 100644
--- a/tests/test_nodes.py
+++ b/tests/test_nodes.py
@@ -164,3 +164,47 @@ def test_list_nodes_marks_running_services(client, tmp_path):
 
     data = r.json()
     assert data[0]["gpus"][0]["services_running"] == ["cf-text"]
+
+
+# ── GET /api/nodes-mgmt/nodes/{node_id}/profile ────────────────────────────────
+
+def test_get_profile_returns_parsed_yaml(client, tmp_path):
+    profiles_dir = tmp_path / "profiles"
+    _write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
+    profile = {
+        "services": {"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}}},
+        "nodes": {"heimdall": {"gpus": [], "agent_url": "http://10.1.10.71:7701"}},
+    }
+    _write_profile(profiles_dir, "heimdall", profile)
+
+    r = client.get("/api/nodes-mgmt/nodes/heimdall/profile")
+    assert r.status_code == 200
+    data = r.json()
+    assert "services" in data
+    assert "cf-text" in data["services"]
+
+
+def test_get_profile_404_when_missing(client, tmp_path):
+    _write_config(tmp_path, {"profiles_dir": str(tmp_path / "profiles")})
+    r = client.get("/api/nodes-mgmt/nodes/nonexistent/profile")
+    assert r.status_code == 404
+
+
+def test_get_profile_500_on_malformed_yaml(client, tmp_path):
+    profiles_dir = tmp_path / "profiles"
+    profiles_dir.mkdir()
+    _write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
+    (profiles_dir / "bad.yaml").write_text("key: [unclosed", encoding="utf-8")
+
+    r = client.get("/api/nodes-mgmt/nodes/bad/profile")
+    assert r.status_code == 500
+
+
+def test_get_profile_500_when_top_level_not_mapping(client, tmp_path):
+    profiles_dir = tmp_path / "profiles"
+    profiles_dir.mkdir()
+    _write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
+    (profiles_dir / "listy.yaml").write_text("- a\n- b\n", encoding="utf-8")
+
+    r = client.get("/api/nodes-mgmt/nodes/listy/profile")
+    assert r.status_code == 500