feat: add Ollama list/pull-SSE/delete endpoints
This commit is contained in:
parent
55b017ba3b
commit
5702a7190b
2 changed files with 190 additions and 0 deletions
66
app/nodes.py
66
app/nodes.py
|
|
@ -17,6 +17,7 @@ from urllib.parse import urlparse
|
|||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -291,3 +292,68 @@ def update_gpu_services(node_id: str, gpu_id: int, body: UpdateServicesRequest)
|
|||
logger.warning("Coordinator reload failed for node %s: %s", node_id, exc)
|
||||
|
||||
return {"ok": True, "reloaded": reloaded, "warnings": []}
|
||||
|
||||
# ── Ollama model management ────────────────────────────────────────────────────
|
||||
|
||||
class PullRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@router.get("/nodes/{node_id}/models/ollama")
|
||||
def list_ollama_models(node_id: str) -> dict:
|
||||
"""Proxy GET {ollama_url}/api/tags for a specific node."""
|
||||
import httpx
|
||||
|
||||
ollama_url = _get_ollama_url(node_id)
|
||||
try:
|
||||
r = httpx.get(f"{ollama_url}/api/tags", timeout=10.0)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except Exception as exc:
|
||||
return {"error": str(exc)}
|
||||
|
||||
|
||||
@router.post("/nodes/{node_id}/models/ollama/pull")
|
||||
def pull_ollama_model(node_id: str, body: PullRequest) -> StreamingResponse:
|
||||
"""Stream Ollama pull progress as SSE events."""
|
||||
import httpx
|
||||
|
||||
if not body.name:
|
||||
raise HTTPException(400, "name is required")
|
||||
|
||||
ollama_url = _get_ollama_url(node_id)
|
||||
|
||||
def stream():
|
||||
try:
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
f"{ollama_url}/api/pull",
|
||||
json={"name": body.name, "stream": True},
|
||||
timeout=300.0,
|
||||
) as resp:
|
||||
for line in resp.iter_lines():
|
||||
if line:
|
||||
yield f"data: {line}\n\n"
|
||||
except Exception as exc:
|
||||
import json
|
||||
yield f"data: {json.dumps({'error': str(exc)})}\n\n"
|
||||
|
||||
return StreamingResponse(stream(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.delete("/nodes/{node_id}/models/ollama/{name:path}")
|
||||
def delete_ollama_model(node_id: str, name: str) -> dict:
|
||||
"""Proxy DELETE to Ollama for a specific node."""
|
||||
import httpx
|
||||
|
||||
ollama_url = _get_ollama_url(node_id)
|
||||
try:
|
||||
r = httpx.request("DELETE", f"{ollama_url}/api/delete", json={"name": name}, timeout=10.0)
|
||||
if r.status_code == 404:
|
||||
raise HTTPException(404, f"Model '{name}' not found on node {node_id}")
|
||||
r.raise_for_status()
|
||||
return {"ok": True}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, f"Ollama unreachable: {exc}")
|
||||
|
|
|
|||
|
|
@ -345,3 +345,127 @@ def test_update_services_reload_failure_returns_reloaded_false(client, tmp_path)
|
|||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["reloaded"] is False
|
||||
|
||||
# ── Ollama endpoints ───────────────────────────────────────────────────────────
|
||||
|
||||
_OLLAMA_PROFILE = {
|
||||
"services": {},
|
||||
"nodes": {
|
||||
"heimdall": {
|
||||
"gpus": [],
|
||||
"agent_url": "http://10.1.10.71:7701",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_list_ollama_models_proxies_tags(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_tags = MagicMock()
|
||||
mock_tags.raise_for_status = MagicMock()
|
||||
mock_tags.json.return_value = {
|
||||
"models": [{"name": "nomic-embed-text", "size": 274000000, "modified_at": "2025-01-01"}]
|
||||
}
|
||||
|
||||
with patch("httpx.get", return_value=mock_tags):
|
||||
r = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["models"]) == 1
|
||||
assert data["models"][0]["name"] == "nomic-embed-text"
|
||||
|
||||
|
||||
def test_list_ollama_models_unreachable_returns_error(client, tmp_path):
|
||||
import httpx as _httpx
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
with patch("httpx.get", side_effect=_httpx.ConnectError("refused")):
|
||||
r = client.get("/api/nodes-mgmt/nodes/heimdall/models/ollama")
|
||||
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert "error" in data
|
||||
|
||||
|
||||
def test_pull_ollama_model_streams_sse(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.iter_lines.return_value = iter([
|
||||
'{"status": "pulling manifest"}',
|
||||
'{"status": "pulling", "digest": "sha256-abc", "total": 1000, "completed": 500}',
|
||||
'{"status": "success"}',
|
||||
])
|
||||
|
||||
with patch("httpx.stream") as mock_stream_fn:
|
||||
mock_stream_fn.return_value.__enter__ = MagicMock(return_value=mock_resp)
|
||||
mock_stream_fn.return_value.__exit__ = MagicMock(return_value=False)
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
|
||||
json={"name": "nomic-embed-text"},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
body = r.text
|
||||
assert 'data: {"status": "pulling manifest"}' in body
|
||||
assert 'data: {"status": "success"}' in body
|
||||
|
||||
|
||||
def test_pull_ollama_model_error_event_in_stream(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.iter_lines.return_value = iter([
|
||||
'{"error": "permission denied: /var/lib/ollama/sha256-abc-partial-0"}',
|
||||
])
|
||||
|
||||
with patch("httpx.stream") as mock_stream_fn:
|
||||
mock_stream_fn.return_value.__enter__ = MagicMock(return_value=mock_resp)
|
||||
mock_stream_fn.return_value.__exit__ = MagicMock(return_value=False)
|
||||
r = client.post(
|
||||
"/api/nodes-mgmt/nodes/heimdall/models/ollama/pull",
|
||||
json={"name": "nomic-embed-text"},
|
||||
)
|
||||
|
||||
assert r.status_code == 200
|
||||
assert "permission denied" in r.text
|
||||
|
||||
|
||||
def test_delete_ollama_model_proxies_delete(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_del = MagicMock()
|
||||
mock_del.status_code = 200
|
||||
mock_del.raise_for_status = MagicMock()
|
||||
|
||||
with patch("httpx.request", return_value=mock_del):
|
||||
r = client.delete("/api/nodes-mgmt/nodes/heimdall/models/ollama/nomic-embed-text")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"ok": True}
|
||||
|
||||
|
||||
def test_delete_ollama_model_404_when_not_found(client, tmp_path):
|
||||
profiles_dir = tmp_path / "profiles"
|
||||
_write_config(tmp_path, {"profiles_dir": str(profiles_dir)})
|
||||
_write_profile(profiles_dir, "heimdall", _OLLAMA_PROFILE)
|
||||
|
||||
mock_del = MagicMock()
|
||||
mock_del.status_code = 404
|
||||
|
||||
with patch("httpx.request", return_value=mock_del):
|
||||
r = client.delete("/api/nodes-mgmt/nodes/heimdall/models/ollama/missing-model")
|
||||
|
||||
assert r.status_code == 404
|
||||
|
|
|
|||
Loading…
Reference in a new issue