fix: align train job/results API envelope, config_json key, progress SSE, dashboard model_key
- GET /api/train/jobs now returns {"jobs":[...]} instead of bare array
- GET /api/train/results now returns {"results":[...]} instead of bare array
- POST /api/train/jobs body key renamed config -> config_json to match Pydantic model
- SSE log handler now handles 'progress' event type (backend never emits 'log')
- Dashboard _get_active_jobs() adds model_key to SELECT and return dict
- corrections.py docstring updated: both /api/corrections and /api/sft prefixes noted
- test_train.py assertions updated to unwrap new envelope shapes
This commit is contained in:
parent
13d1a394d5
commit
e11db5ccd9
5 changed files with 14 additions and 14 deletions
|
|
@ -155,9 +155,9 @@ def _get_active_jobs() -> list[dict]:
|
||||||
_init_db()
|
_init_db()
|
||||||
with _db() as conn:
|
with _db() as conn:
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT id, type, status FROM jobs WHERE status IN ('queued', 'running')"
|
"SELECT id, type, model_key, status FROM jobs WHERE status IN ('queued', 'running')"
|
||||||
).fetchall()
|
).fetchall()
|
||||||
return [{"id": r["id"], "type": r["type"], "status": r["status"]} for r in rows]
|
return [{"id": r["id"], "type": r["type"], "model_key": r["model_key"], "status": r["status"]} for r in rows]
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Failed to query train jobs DB: %s", exc)
|
logger.warning("Failed to query train jobs DB: %s", exc)
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
|
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
|
||||||
|
|
||||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||||
api.py includes this router with prefix="/api/sft".
|
Primary prefix: /api/corrections (backward-compat alias: /api/sft -- pending Vue SPA migration)
|
||||||
|
|
||||||
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
|
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
|
||||||
testability pattern as api.py -- override them via set_data_dir() and
|
testability pattern as api.py -- override them via set_data_dir() and
|
||||||
|
|
|
||||||
|
|
@ -174,11 +174,11 @@ class CreateJobRequest(BaseModel):
|
||||||
# -- Routes ------------------------------------------------------------
|
# -- Routes ------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/jobs")
|
@router.get("/jobs")
|
||||||
def list_jobs() -> list[dict]:
|
def list_jobs() -> dict:
|
||||||
_init_db()
|
_init_db()
|
||||||
with _db() as conn:
|
with _db() as conn:
|
||||||
rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall()
|
rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall()
|
||||||
return [_row_to_dict(r) for r in rows]
|
return {"jobs": [_row_to_dict(r) for r in rows]}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/jobs")
|
@router.post("/jobs")
|
||||||
|
|
@ -321,9 +321,9 @@ def run_job(job_id: str) -> StreamingResponse:
|
||||||
|
|
||||||
|
|
||||||
@router.get("/results")
|
@router.get("/results")
|
||||||
def list_results() -> list[dict]:
|
def list_results() -> dict:
|
||||||
if not _MODELS_DIR.exists():
|
if not _MODELS_DIR.exists():
|
||||||
return []
|
return {"results": []}
|
||||||
results = []
|
results = []
|
||||||
for sub in _MODELS_DIR.iterdir():
|
for sub in _MODELS_DIR.iterdir():
|
||||||
if not sub.is_dir():
|
if not sub.is_dir():
|
||||||
|
|
@ -336,4 +336,4 @@ def list_results() -> list[dict]:
|
||||||
results.append(info)
|
results.append(info)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
|
logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
|
||||||
return results
|
return {"results": results}
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ def _parse_sse(content: bytes) -> list[dict]:
|
||||||
def test_list_jobs_empty(client):
|
def test_list_jobs_empty(client):
|
||||||
r = client.get("/api/train/jobs")
|
r = client.get("/api/train/jobs")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
assert r.json() == []
|
assert r.json() == {"jobs": []}
|
||||||
|
|
||||||
|
|
||||||
def test_create_job_returns_queued_record(client):
|
def test_create_job_returns_queued_record(client):
|
||||||
|
|
@ -57,7 +57,7 @@ def test_create_job_appears_in_list(client):
|
||||||
client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||||
r = client.get("/api/train/jobs")
|
r = client.get("/api/train/jobs")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
assert len(r.json()) == 1
|
assert len(r.json()["jobs"]) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_get_job_returns_record(client):
|
def test_get_job_returns_record(client):
|
||||||
|
|
@ -171,7 +171,7 @@ def test_run_unknown_job_returns_404(client):
|
||||||
def test_results_empty_when_no_models_dir(client):
|
def test_results_empty_when_no_models_dir(client):
|
||||||
r = client.get("/api/train/results")
|
r = client.get("/api/train/results")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
assert r.json() == []
|
assert r.json() == {"results": []}
|
||||||
|
|
||||||
|
|
||||||
def test_results_returns_training_info(client, tmp_path):
|
def test_results_returns_training_info(client, tmp_path):
|
||||||
|
|
@ -184,4 +184,4 @@ def test_results_returns_training_info(client, tmp_path):
|
||||||
r = client.get("/api/train/results")
|
r = client.get("/api/train/results")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
data = r.json()
|
data = r.json()
|
||||||
assert any(d["name"] == "avocet-deberta-small" for d in data)
|
assert any(d["name"] == "avocet-deberta-small" for d in data["results"])
|
||||||
|
|
|
||||||
|
|
@ -205,7 +205,7 @@ async function submitJob() {
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
type: form.value.type,
|
type: form.value.type,
|
||||||
model_key: form.value.model_key,
|
model_key: form.value.model_key,
|
||||||
config,
|
config_json: config,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
|
|
@ -254,7 +254,7 @@ function openLog(id: string) {
|
||||||
closeSSE = useApiSSE(
|
closeSSE = useApiSSE(
|
||||||
`/api/train/jobs/${encodeURIComponent(id)}/run`,
|
`/api/train/jobs/${encodeURIComponent(id)}/run`,
|
||||||
(data) => {
|
(data) => {
|
||||||
if (data.type === 'log' || data.type === 'error') {
|
if (data.type === 'log' || data.type === 'progress' || data.type === 'error') {
|
||||||
logLines.value = [...logLines.value, String(data.message ?? '')]
|
logLines.value = [...logLines.value, String(data.message ?? '')]
|
||||||
nextTick(() => {
|
nextTick(() => {
|
||||||
if (logPanelEl.value) {
|
if (logPanelEl.value) {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue