fix: align train job/results API envelope, config_json key, progress SSE, dashboard model_key
- GET /api/train/jobs now returns {"jobs":[...]} instead of bare array
- GET /api/train/results now returns {"results":[...]} instead of bare array
- POST /api/train/jobs body key renamed config -> config_json to match Pydantic model
- SSE log handler now handles 'progress' event type (backend never emits 'log')
- Dashboard _get_active_jobs() adds model_key to SELECT and return dict
- corrections.py docstring updated: both /api/corrections and /api/sft prefixes noted
- test_train.py assertions updated to unwrap new envelope shapes
This commit is contained in:
parent
13d1a394d5
commit
e11db5ccd9
5 changed files with 14 additions and 14 deletions
|
|
@ -155,9 +155,9 @@ def _get_active_jobs() -> list[dict]:
|
|||
_init_db()
|
||||
with _db() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT id, type, status FROM jobs WHERE status IN ('queued', 'running')"
|
||||
"SELECT id, type, model_key, status FROM jobs WHERE status IN ('queued', 'running')"
|
||||
).fetchall()
|
||||
return [{"id": r["id"], "type": r["type"], "status": r["status"]} for r in rows]
|
||||
return [{"id": r["id"], "type": r["type"], "model_key": r["model_key"], "status": r["status"]} for r in rows]
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to query train jobs DB: %s", exc)
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/sft".
|
||||
Primary prefix: /api/corrections (backward-compat alias: /api/sft -- pending Vue SPA migration)
|
||||
|
||||
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
|
||||
testability pattern as api.py -- override them via set_data_dir() and
|
||||
|
|
|
|||
|
|
@ -174,11 +174,11 @@ class CreateJobRequest(BaseModel):
|
|||
# -- Routes ------------------------------------------------------------
|
||||
|
||||
@router.get("/jobs")
|
||||
def list_jobs() -> list[dict]:
|
||||
def list_jobs() -> dict:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall()
|
||||
return [_row_to_dict(r) for r in rows]
|
||||
return {"jobs": [_row_to_dict(r) for r in rows]}
|
||||
|
||||
|
||||
@router.post("/jobs")
|
||||
|
|
@ -321,9 +321,9 @@ def run_job(job_id: str) -> StreamingResponse:
|
|||
|
||||
|
||||
@router.get("/results")
|
||||
def list_results() -> list[dict]:
|
||||
def list_results() -> dict:
|
||||
if not _MODELS_DIR.exists():
|
||||
return []
|
||||
return {"results": []}
|
||||
results = []
|
||||
for sub in _MODELS_DIR.iterdir():
|
||||
if not sub.is_dir():
|
||||
|
|
@ -336,4 +336,4 @@ def list_results() -> list[dict]:
|
|||
results.append(info)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
|
||||
return results
|
||||
return {"results": results}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def _parse_sse(content: bytes) -> list[dict]:
|
|||
def test_list_jobs_empty(client):
|
||||
r = client.get("/api/train/jobs")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
assert r.json() == {"jobs": []}
|
||||
|
||||
|
||||
def test_create_job_returns_queued_record(client):
|
||||
|
|
@ -57,7 +57,7 @@ def test_create_job_appears_in_list(client):
|
|||
client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
r = client.get("/api/train/jobs")
|
||||
assert r.status_code == 200
|
||||
assert len(r.json()) == 1
|
||||
assert len(r.json()["jobs"]) == 1
|
||||
|
||||
|
||||
def test_get_job_returns_record(client):
|
||||
|
|
@ -171,7 +171,7 @@ def test_run_unknown_job_returns_404(client):
|
|||
def test_results_empty_when_no_models_dir(client):
|
||||
r = client.get("/api/train/results")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
assert r.json() == {"results": []}
|
||||
|
||||
|
||||
def test_results_returns_training_info(client, tmp_path):
|
||||
|
|
@ -184,4 +184,4 @@ def test_results_returns_training_info(client, tmp_path):
|
|||
r = client.get("/api/train/results")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert any(d["name"] == "avocet-deberta-small" for d in data)
|
||||
assert any(d["name"] == "avocet-deberta-small" for d in data["results"])
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ async function submitJob() {
|
|||
body: JSON.stringify({
|
||||
type: form.value.type,
|
||||
model_key: form.value.model_key,
|
||||
config,
|
||||
config_json: config,
|
||||
}),
|
||||
})
|
||||
if (!res.ok) {
|
||||
|
|
@ -254,7 +254,7 @@ function openLog(id: string) {
|
|||
closeSSE = useApiSSE(
|
||||
`/api/train/jobs/${encodeURIComponent(id)}/run`,
|
||||
(data) => {
|
||||
if (data.type === 'log' || data.type === 'error') {
|
||||
if (data.type === 'log' || data.type === 'progress' || data.type === 'error') {
|
||||
logLines.value = [...logLines.value, String(data.message ?? '')]
|
||||
nextTick(() => {
|
||||
if (logPanelEl.value) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue