feat: add GET /api/nodes-mgmt/nodes/{node_id}/profile endpoint
This commit is contained in:
parent
47cb9f661f
commit
fd8cb622a1
2 changed files with 138 additions and 0 deletions
104
app/nodes.py
104
app/nodes.py
|
|
@ -11,11 +11,13 @@ Module-level globals follow the set_config_dir() testability pattern from cforch
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -187,3 +189,105 @@ def list_nodes() -> list:
|
||||||
})
|
})
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/nodes/{node_id}/profile")
def get_node_profile(node_id: str) -> dict:
    """Return the full parsed profile YAML for a node.

    Args:
        node_id: Node identifier; passed to ``_profile_path`` to locate the
            profile file.

    Returns:
        The top-level mapping parsed from the profile file, or an empty dict
        when the file parses to an empty/falsy document.

    Raises:
        HTTPException: 404 when no profile file exists for the node; 500 when
            the file is not valid YAML or its top level is not a mapping.
    """
    p = _profile_path(node_id)
    if p is None or not p.exists():
        raise HTTPException(404, f"No profile found for node {node_id}")
    try:
        data = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        # Chain the parser error so the original traceback is kept in logs.
        raise HTTPException(500, f"Malformed profile YAML: {exc}") from exc
    if not isinstance(data, dict):
        # Syntactically valid YAML can still be a list or scalar; this endpoint
        # promises a mapping, so report it as malformed instead of returning it.
        raise HTTPException(
            500,
            f"Malformed profile YAML: expected a mapping, got {type(data).__name__}",
        )
    return data
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateServicesRequest(BaseModel):
    """Request body carrying the service names to assign to a GPU."""

    # Service names to assign, as submitted by the caller.
    services: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/nodes/{node_id}/gpu/{gpu_id}/services")
def update_gpu_services(node_id: str, gpu_id: int, body: UpdateServicesRequest) -> dict:
    """Set the service assignment for one GPU in a node's profile.

    Every requested service is validated against the GPU's ``compute_cap`` and
    ``vram_mb`` before anything is written. The updated profile is written to a
    temporary file and then moved over the original with ``os.replace``.
    Afterwards the coordinator (if ``coordinator_url`` is configured) is asked
    to reload the profile; that step is best effort.

    Args:
        node_id: Node whose profile is edited.
        gpu_id: ``id`` of the GPU entry inside the node's ``gpus`` list.
        body: Request payload holding the new list of service names.

    Returns:
        ``{"ok": True, "reloaded": <bool>, "warnings": []}`` where ``reloaded``
        is True only when the coordinator answered with a status below 300.

    Raises:
        HTTPException: 404 when the profile file or the GPU entry is missing;
            422 when a service is not defined in the profile or does not fit
            the GPU; 500 when the profile is malformed or cannot be written.
    """
    # Function-scope import, as in the original: only this endpoint uses httpx.
    import httpx

    cfg = _load_config()
    coordinator_url = cfg.get("coordinator_url", "") or ""

    p = _profile_path(node_id)
    if p is None or not p.exists():
        raise HTTPException(404, f"No profile found for node {node_id}")

    try:
        profile = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as exc:
        raise HTTPException(500, f"Malformed profile YAML: {exc}") from exc
    if not isinstance(profile, dict):
        # Valid YAML that is a list/scalar would otherwise crash on .get() below.
        raise HTTPException(
            500,
            f"Malformed profile YAML: expected a mapping, got {type(profile).__name__}",
        )

    # "or {}" / "or []" also covers keys that are present but set to null.
    nodes_section = profile.get("nodes", {}) or {}
    node_entry = nodes_section.get(node_id, {}) or {}
    gpu_list = node_entry.get("gpus", []) or []

    gpu_entry = next(
        (g for g in gpu_list if isinstance(g, dict) and g.get("id") == gpu_id),
        None,
    )
    if gpu_entry is None:
        raise HTTPException(404, f"GPU {gpu_id} not found in profile for node {node_id}")

    gpu_compute_cap: float = gpu_entry.get("compute_cap") or 0.0
    gpu_vram_mb: int = gpu_entry.get("vram_mb") or 0
    services_def = profile.get("services", {}) or {}

    for svc_name in body.services:
        if svc_name not in services_def:
            raise HTTPException(422, f"Service '{svc_name}' not defined in profile services dict")
        # A service declared with an empty body parses to None; treat it as
        # "no constraints" instead of failing on None.get().
        svc = services_def[svc_name] or {}

        min_cap: float = svc.get("min_compute_cap") or 0.0
        if gpu_compute_cap < min_cap:
            raise HTTPException(
                422,
                f"Service '{svc_name}' requires compute_cap >= {min_cap}; GPU has {gpu_compute_cap}",
            )

        catalog = svc.get("catalog") or {}
        if catalog:
            # The GPU must fit at least the smallest catalog entry. Null entries
            # and null vram_mb values count as 0 so the comparison cannot raise.
            min_catalog_vram = min(
                ((m or {}).get("vram_mb") or 0) for m in catalog.values()
            )
        else:
            # No catalog: fall back to the service-level max_mb figure.
            min_catalog_vram = svc.get("max_mb") or 0
        if gpu_vram_mb < min_catalog_vram:
            raise HTTPException(
                422,
                f"Service '{svc_name}' requires {min_catalog_vram} MB VRAM; GPU has {gpu_vram_mb} MB",
            )

    # Build new containers instead of mutating the parsed profile in place.
    new_gpu_list = [
        ({**g, "services": body.services} if isinstance(g, dict) and g.get("id") == gpu_id else g)
        for g in gpu_list
    ]
    new_profile = {
        **profile,
        "nodes": {
            **nodes_section,
            node_id: {**node_entry, "gpus": new_gpu_list},
        },
    }

    # Serialize first so a dump failure never touches the filesystem.
    serialized = yaml.dump(new_profile, default_flow_style=False)

    # Write to a sibling .tmp file, then move it over the real profile.
    tmp_yaml = Path(str(p) + ".tmp")
    try:
        tmp_yaml.write_text(serialized, encoding="utf-8")
        os.replace(tmp_yaml, p)
    except OSError as exc:
        # Do not leave a stale temp file next to the profile.
        try:
            tmp_yaml.unlink(missing_ok=True)
        except OSError:
            logger.warning("Could not remove temporary profile file %s", tmp_yaml)
        raise HTTPException(500, f"Failed to write profile for node {node_id}: {exc}") from exc

    # Ask the coordinator to reload the profile. The profile is already saved
    # at this point, so a failure here is logged and reported via "reloaded".
    reloaded = False
    if coordinator_url:
        # Strip trailing slashes so the joined URL never contains "//api".
        reload_url = f"{coordinator_url.rstrip('/')}/api/nodes/{node_id}/reload-profile"
        try:
            rr = httpx.post(reload_url, timeout=5.0)
            reloaded = rr.status_code < 300
        except Exception:
            logger.warning(
                "Coordinator profile reload failed for node %s", node_id, exc_info=True
            )

    return {"ok": True, "reloaded": reloaded, "warnings": []}
|
||||||
|
|
|
||||||
|
|
@ -164,3 +164,37 @@ def test_list_nodes_marks_running_services(client, tmp_path):
|
||||||
|
|
||||||
data = r.json()
|
data = r.json()
|
||||||
assert data[0]["gpus"][0]["services_running"] == ["cf-text"]
|
assert data[0]["gpus"][0]["services_running"] == ["cf-text"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /api/nodes-mgmt/nodes/{node_id}/profile ────────────────────────────────
|
||||||
|
|
||||||
|
def test_get_profile_returns_parsed_yaml(client, tmp_path):
    """GET on an existing profile answers 200 and exposes the parsed services."""
    profile_root = tmp_path / "profiles"
    _write_config(tmp_path, {"profiles_dir": str(profile_root)})

    cf_text_service = {"min_compute_cap": 7.0, "max_mb": 8192, "catalog": {}}
    heimdall_node = {"gpus": [], "agent_url": "http://10.1.10.71:7701"}
    _write_profile(
        profile_root,
        "heimdall",
        {
            "services": {"cf-text": cf_text_service},
            "nodes": {"heimdall": heimdall_node},
        },
    )

    response = client.get("/api/nodes-mgmt/nodes/heimdall/profile")

    assert response.status_code == 200
    payload = response.json()
    assert "services" in payload
    assert "cf-text" in payload["services"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_profile_404_when_missing(client, tmp_path):
    """Requesting the profile of a node that has no profile file yields 404."""
    # This test writes no profile for "nonexistent" under the configured directory.
    profile_root = tmp_path / "profiles"
    _write_config(tmp_path, {"profiles_dir": str(profile_root)})

    response = client.get("/api/nodes-mgmt/nodes/nonexistent/profile")

    assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_profile_500_on_malformed_yaml(client, tmp_path):
    """A profile file containing unparseable YAML makes the endpoint answer 500."""
    profile_root = tmp_path / "profiles"
    profile_root.mkdir()
    _write_config(tmp_path, {"profiles_dir": str(profile_root)})

    # An unterminated flow sequence is a YAML syntax error.
    broken_profile = profile_root / "bad.yaml"
    broken_profile.write_text("key: [unclosed", encoding="utf-8")

    response = client.get("/api/nodes-mgmt/nodes/bad/profile")

    assert response.status_code == 500
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue