feat: add GET /api/nodes-mgmt/nodes/{node_id}/profile endpoint

This commit is contained in:
pyr0ball 2026-05-05 20:31:22 -07:00
parent 47cb9f661f
commit fd8cb622a1
2 changed files with 138 additions and 0 deletions

View file

@ -11,11 +11,13 @@ Module-level globals follow the set_config_dir() testability pattern from cforch
from __future__ import annotations
import logging
import os
from pathlib import Path
from urllib.parse import urlparse
import yaml
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
logger = logging.getLogger(__name__)
@ -187,3 +189,105 @@ def list_nodes() -> list:
})
return result
@router.get("/nodes/{node_id}/profile")
def get_node_profile(node_id: str) -> dict:
    """Return the full parsed profile YAML for a node.

    Args:
        node_id: Node identifier from the URL path; ``_profile_path`` maps it
            to the profile file on disk.

    Returns:
        The parsed YAML document as a dict. A document that parses to a
        falsy value (for example an empty file) is returned as ``{}``.

    Raises:
        HTTPException: 404 when no profile file exists for the node; 500 when
            the file is not valid YAML or its top level is not a mapping.
    """
    p = _profile_path(node_id)
    if p is None or not p.exists():
        raise HTTPException(404, f"No profile found for node {node_id}")
    try:
        data = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        # Chain the parser error so the original traceback is kept in logs.
        raise HTTPException(500, f"Malformed profile YAML: {exc}") from exc
    if not isinstance(data, dict):
        # Syntactically valid YAML can still be a list or scalar; the endpoint
        # promises a mapping, so report that as a malformed profile too.
        raise HTTPException(
            500,
            f"Malformed profile YAML: expected a mapping at the top level, got {type(data).__name__}",
        )
    return data
class UpdateServicesRequest(BaseModel):
    """Request body for the GPU service-assignment endpoint."""

    # Names of the services to assign to the GPU; update_gpu_services checks
    # each name against the profile's top-level "services" mapping.
    services: list[str]
def _service_fit_error(
    svc_name: str, svc: dict, gpu_compute_cap: float, gpu_vram_mb: int
) -> str | None:
    """Return the 422 detail explaining why *svc* cannot run on the GPU, or None if it fits."""
    min_cap: float = svc.get("min_compute_cap", 0.0) or 0.0
    if gpu_compute_cap < min_cap:
        return f"Service '{svc_name}' requires compute_cap >= {min_cap}; GPU has {gpu_compute_cap}"

    catalog = svc.get("catalog", {}) or {}
    if catalog:
        # The smallest catalog entry is the minimum VRAM the service can need.
        # Missing, null or non-mapping entries count as 0 instead of crashing.
        min_catalog_vram = min(
            ((m.get("vram_mb") or 0) if isinstance(m, dict) else 0 for m in catalog.values()),
            default=0,
        )
    else:
        # No catalog: fall back to the service-level max_mb (null counts as 0).
        min_catalog_vram = svc.get("max_mb") or 0
    if gpu_vram_mb < min_catalog_vram:
        return f"Service '{svc_name}' requires {min_catalog_vram} MB VRAM; GPU has {gpu_vram_mb} MB"
    return None


def _write_yaml_atomically(path: Path, data: dict) -> None:
    """Serialise *data* to ``<path>.tmp`` and rename it over *path*, removing the temp file on failure."""
    # Serialise first so a dump error never leaves a partial temp file behind.
    text = yaml.dump(data, default_flow_style=False)
    tmp_yaml = Path(str(path) + ".tmp")
    try:
        tmp_yaml.write_text(text, encoding="utf-8")
        os.replace(tmp_yaml, path)
    except OSError:
        tmp_yaml.unlink(missing_ok=True)
        raise


@router.post("/nodes/{node_id}/gpu/{gpu_id}/services")
def update_gpu_services(node_id: str, gpu_id: int, body: UpdateServicesRequest) -> dict:
    """Set service assignment for a GPU with compatibility validation, then atomic write.

    Args:
        node_id: Node whose profile file is updated.
        gpu_id: ``id`` of the GPU entry inside ``nodes[node_id]["gpus"]``.
        body: Request payload carrying the new list of service names.

    Returns:
        ``{"ok": True, "reloaded": <bool>, "warnings": []}`` where ``reloaded``
        is True only when the coordinator answered the reload request with a
        status code below 300.

    Raises:
        HTTPException: 404 when the profile file or the GPU entry is missing;
            422 when a requested service is not defined in the profile or does
            not fit the GPU's compute capability / VRAM; 500 when the profile
            is not valid YAML or its top level is not a mapping.
    """
    import httpx

    cfg = _load_config()
    # Drop a trailing slash so the reload URL never contains "//api/...".
    coordinator_url = (cfg.get("coordinator_url", "") or "").rstrip("/")

    p = _profile_path(node_id)
    if p is None or not p.exists():
        raise HTTPException(404, f"No profile found for node {node_id}")

    try:
        profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        raise HTTPException(500, f"Malformed profile YAML: {exc}") from exc
    if not isinstance(profile, dict):
        raise HTTPException(
            500,
            f"Malformed profile YAML: expected a mapping at the top level, got {type(profile).__name__}",
        )

    nodes_section = profile.get("nodes", {}) or {}
    node_entry = nodes_section.get(node_id, {}) or {}
    gpu_list = node_entry.get("gpus", []) or []

    gpu_entry = next(
        (g for g in gpu_list if isinstance(g, dict) and g.get("id") == gpu_id),
        None,
    )
    if gpu_entry is None:
        raise HTTPException(404, f"GPU {gpu_id} not found in profile for node {node_id}")

    gpu_compute_cap: float = gpu_entry.get("compute_cap") or 0.0
    gpu_vram_mb: int = gpu_entry.get("vram_mb") or 0

    services_def = profile.get("services", {}) or {}
    for svc_name in body.services:
        if svc_name not in services_def:
            raise HTTPException(422, f"Service '{svc_name}' not defined in profile services dict")
        # A service declared with a null body ("name:") is treated as having
        # no constraints rather than raising AttributeError.
        svc = services_def[svc_name] or {}
        problem = _service_fit_error(svc_name, svc, gpu_compute_cap, gpu_vram_mb)
        if problem is not None:
            raise HTTPException(422, problem)

    # Build the updated document without mutating the parsed profile.
    new_gpu_list = [
        ({**g, "services": body.services} if isinstance(g, dict) and g.get("id") == gpu_id else g)
        for g in gpu_list
    ]
    new_profile = {
        **profile,
        "nodes": {
            **nodes_section,
            node_id: {**node_entry, "gpus": new_gpu_list},
        },
    }

    _write_yaml_atomically(p, new_profile)

    # Ask the coordinator to reload the profile. This stays best-effort: the
    # file is already saved, so a failure here is logged and reported through
    # "reloaded" rather than failing the request.
    reloaded = False
    if coordinator_url:
        try:
            rr = httpx.post(
                f"{coordinator_url}/api/nodes/{node_id}/reload-profile", timeout=5.0
            )
        except Exception:
            logger.warning(
                "Coordinator profile reload request failed for node %s", node_id, exc_info=True
            )
        else:
            reloaded = rr.status_code < 300
            if not reloaded:
                logger.warning(
                    "Coordinator profile reload for node %s returned HTTP %s",
                    node_id,
                    rr.status_code,
                )

    return {"ok": True, "reloaded": reloaded, "warnings": []}

View file

@ -164,3 +164,37 @@ def test_list_nodes_marks_running_services(client, tmp_path):
data = r.json()
assert data[0]["gpus"][0]["services_running"] == ["cf-text"]
# ── GET /api/nodes-mgmt/nodes/{node_id}/profile ────────────────────────────────
def test_get_profile_returns_parsed_yaml(client, tmp_path):
    """A stored profile is served with HTTP 200 and its services section intact."""
    profile_root = tmp_path / "profiles"
    _write_config(tmp_path, {"profiles_dir": str(profile_root)})

    services = {"cf-text": {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}}}
    nodes = {"heimdall": {"gpus": [], "agent_url": "http://10.1.10.71:7701"}}
    _write_profile(profile_root, "heimdall", {"services": services, "nodes": nodes})

    response = client.get("/api/nodes-mgmt/nodes/heimdall/profile")

    assert response.status_code == 200
    payload = response.json()
    assert "services" in payload
    assert "cf-text" in payload["services"]
def test_get_profile_404_when_missing(client, tmp_path):
    """Requesting a node for which this test wrote no profile yields HTTP 404."""
    config = {"profiles_dir": str(tmp_path / "profiles")}
    _write_config(tmp_path, config)

    response = client.get("/api/nodes-mgmt/nodes/nonexistent/profile")

    assert response.status_code == 404
def test_get_profile_500_on_malformed_yaml(client, tmp_path):
    """A profile file containing unparseable YAML is reported as HTTP 500."""
    profile_root = tmp_path / "profiles"
    profile_root.mkdir()
    _write_config(tmp_path, {"profiles_dir": str(profile_root)})

    # An unterminated flow sequence makes the YAML parser fail.
    broken_file = profile_root / "bad.yaml"
    broken_file.write_text("key: [unclosed", encoding="utf-8")

    response = client.get("/api/nodes-mgmt/nodes/bad/profile")

    assert response.status_code == 500