feat: add GET /api/nodes-mgmt/nodes/{node_id}/profile endpoint

This commit is contained in:
pyr0ball 2026-05-05 20:31:22 -07:00
parent 47cb9f661f
commit fd8cb622a1
2 changed files with 138 additions and 0 deletions

View file

@ -11,11 +11,13 @@ Module-level globals follow the set_config_dir() testability pattern from cforch
from __future__ import annotations from __future__ import annotations
import logging import logging
import os
from pathlib import Path from pathlib import Path
from urllib.parse import urlparse from urllib.parse import urlparse
import yaml import yaml
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -187,3 +189,105 @@ def list_nodes() -> list:
}) })
return result return result
@router.get("/nodes/{node_id}/profile")
def get_node_profile(node_id: str) -> dict:
    """Return the full parsed profile YAML for a node.

    Args:
        node_id: Node identifier; the profile file is located through
            ``_profile_path(node_id)``.

    Returns:
        The parsed top-level mapping of the profile file, or an empty dict
        when the file parses to an empty/falsy value.

    Raises:
        HTTPException: 404 when no profile file exists for ``node_id``;
            500 when the file is not valid YAML or its top-level value is
            not a mapping.
    """
    p = _profile_path(node_id)
    if p is None or not p.exists():
        raise HTTPException(404, f"No profile found for node {node_id}")
    try:
        data = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        # Chain the parser error so the original traceback stays attached.
        raise HTTPException(500, f"Malformed profile YAML: {exc}") from exc
    if not isinstance(data, dict):
        # Valid YAML can still be a bare list or scalar, which would break
        # the declared ``dict`` return contract.
        raise HTTPException(
            500, "Malformed profile YAML: top-level value must be a mapping"
        )
    return data
# Request body accepted by POST /nodes/{node_id}/gpu/{gpu_id}/services.
class UpdateServicesRequest(BaseModel):
    # Names of the services to assign to the GPU. The endpoint rejects any
    # name that is not a key of the profile's top-level ``services`` mapping.
    services: list[str]
@router.post("/nodes/{node_id}/gpu/{gpu_id}/services")
def update_gpu_services(node_id: str, gpu_id: int, body: UpdateServicesRequest) -> dict:
    """Set the service assignment for one GPU, then persist it atomically.

    Every requested service is validated against the GPU before anything is
    written: the service must be defined in the profile's ``services``
    mapping, the GPU's ``compute_cap`` must reach the service's
    ``min_compute_cap``, and the GPU's ``vram_mb`` must cover the smallest
    ``vram_mb`` in the service's ``catalog`` (or the service's ``max_mb``
    when the catalog is empty). Missing or null values count as zero.

    Args:
        node_id: Node whose profile file is updated.
        gpu_id: ``id`` of the GPU entry under ``nodes.<node_id>.gpus``.
        body: Request payload carrying the new list of service names.

    Returns:
        ``{"ok": True, "reloaded": <bool>, "warnings": []}`` where
        ``reloaded`` reports whether the coordinator answered the reload
        request with a status code below 300.

    Raises:
        HTTPException: 404 when the profile file or the GPU entry is
            missing; 422 when a service is undefined or incompatible with
            the GPU; 500 when the profile is not valid YAML or its
            top-level value is not a mapping.
        OSError: Propagated unchanged when writing or renaming the profile
            file fails.
    """
    import httpx

    cfg = _load_config()
    coordinator_url = cfg.get("coordinator_url", "") or ""

    p = _profile_path(node_id)
    if p is None or not p.exists():
        raise HTTPException(404, f"No profile found for node {node_id}")
    try:
        profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        raise HTTPException(500, f"Malformed profile YAML: {exc}") from exc
    if not isinstance(profile, dict):
        # Valid YAML can still be a bare list or scalar; the lookups below
        # need a mapping.
        raise HTTPException(
            500, "Malformed profile YAML: top-level value must be a mapping"
        )

    nodes_section = profile.get("nodes", {}) or {}
    node_entry = nodes_section.get(node_id, {}) or {}
    gpu_list = node_entry.get("gpus", []) or []
    gpu_entry = next(
        (g for g in gpu_list if isinstance(g, dict) and g.get("id") == gpu_id),
        None,
    )
    if gpu_entry is None:
        raise HTTPException(404, f"GPU {gpu_id} not found in profile for node {node_id}")

    gpu_compute_cap: float = gpu_entry.get("compute_cap") or 0.0
    gpu_vram_mb: int = gpu_entry.get("vram_mb") or 0

    services_def = profile.get("services", {}) or {}
    for svc_name in body.services:
        if svc_name not in services_def:
            raise HTTPException(422, f"Service '{svc_name}' not defined in profile services dict")
        # A service key with an empty body parses to None; treat it as a
        # service without requirements instead of crashing on ``.get``.
        svc = services_def[svc_name] or {}

        min_cap: float = svc.get("min_compute_cap", 0.0) or 0.0
        if gpu_compute_cap < min_cap:
            raise HTTPException(
                422,
                f"Service '{svc_name}' requires compute_cap >= {min_cap}; GPU has {gpu_compute_cap}",
            )

        catalog = svc.get("catalog", {}) or {}
        if catalog:
            # The GPU only has to fit the smallest catalog entry. Null or
            # non-mapping entries and null ``vram_mb`` values count as 0.
            min_catalog_vram = min(
                (
                    (m.get("vram_mb") or 0) if isinstance(m, dict) else 0
                    for m in catalog.values()
                ),
                default=0,
            )
        else:
            min_catalog_vram = svc.get("max_mb", 0) or 0
        if gpu_vram_mb < min_catalog_vram:
            raise HTTPException(
                422,
                f"Service '{svc_name}' requires {min_catalog_vram} MB VRAM; GPU has {gpu_vram_mb} MB",
            )

    # Build the updated profile without mutating the parsed structures.
    new_services = list(body.services)
    new_gpu_list = [
        ({**g, "services": new_services} if isinstance(g, dict) and g.get("id") == gpu_id else g)
        for g in gpu_list
    ]
    new_profile = {
        **profile,
        "nodes": {
            **nodes_section,
            node_id: {**node_entry, "gpus": new_gpu_list},
        },
    }

    # Atomic write: write to .tmp then rename over the real file.
    tmp_yaml = Path(str(p) + ".tmp")
    try:
        tmp_yaml.write_text(yaml.dump(new_profile, default_flow_style=False), encoding="utf-8")
        os.replace(tmp_yaml, p)
    except OSError:
        # Do not leave a partial temp file next to the real profile.
        tmp_yaml.unlink(missing_ok=True)
        raise

    # Trigger a coordinator profile reload. This is best-effort: the profile
    # is already saved, so a failed reload is logged and reported through
    # ``reloaded`` rather than failing the request.
    reloaded = False
    if coordinator_url:
        # Strip a trailing slash so the joined URL never contains "//api".
        base_url = coordinator_url.rstrip("/")
        try:
            rr = httpx.post(
                f"{base_url}/api/nodes/{node_id}/reload-profile", timeout=5.0
            )
            reloaded = rr.status_code < 300
            if not reloaded:
                logger.warning(
                    "Coordinator reload for node %s returned HTTP %s",
                    node_id,
                    rr.status_code,
                )
        except Exception:
            logger.warning(
                "Coordinator reload request for node %s failed", node_id, exc_info=True
            )

    return {"ok": True, "reloaded": reloaded, "warnings": []}

View file

@ -164,3 +164,37 @@ def test_list_nodes_marks_running_services(client, tmp_path):
data = r.json() data = r.json()
assert data[0]["gpus"][0]["services_running"] == ["cf-text"] assert data[0]["gpus"][0]["services_running"] == ["cf-text"]
# ── GET /api/nodes-mgmt/nodes/{node_id}/profile ────────────────────────────────
def test_get_profile_returns_parsed_yaml(client, tmp_path):
    """GET on an existing node profile answers 200 with the parsed services."""
    profile_root = tmp_path / "profiles"
    _write_config(tmp_path, {"profiles_dir": str(profile_root)})

    cf_text_service = {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}}
    heimdall_node = {"gpus": [], "agent_url": "http://10.1.10.71:7701"}
    _write_profile(
        profile_root,
        "heimdall",
        {
            "services": {"cf-text": cf_text_service},
            "nodes": {"heimdall": heimdall_node},
        },
    )

    response = client.get("/api/nodes-mgmt/nodes/heimdall/profile")

    assert response.status_code == 200
    payload = response.json()
    assert "services" in payload
    assert "cf-text" in payload["services"]
def test_get_profile_404_when_missing(client, tmp_path):
    """GET for a node that has no profile written answers 404."""
    configured_dir = str(tmp_path / "profiles")
    _write_config(tmp_path, {"profiles_dir": configured_dir})

    response = client.get("/api/nodes-mgmt/nodes/nonexistent/profile")

    assert response.status_code == 404
def test_get_profile_500_on_malformed_yaml(client, tmp_path):
    """GET on a profile file containing invalid YAML answers 500."""
    profile_root = tmp_path / "profiles"
    profile_root.mkdir()
    _write_config(tmp_path, {"profiles_dir": str(profile_root)})

    # An unterminated flow sequence makes the YAML parser fail.
    broken_profile = profile_root / "bad.yaml"
    broken_profile.write_text("key: [unclosed", encoding="utf-8")

    response = client.get("/api/nodes-mgmt/nodes/bad/profile")

    assert response.status_code == 500