fix: path traversal guard, python_bin config, completed_at on Popen failure

This commit is contained in:
pyr0ball 2026-05-01 23:24:00 -07:00
parent 766fbafa02
commit 32d3436bbd

View file

@ -29,6 +29,7 @@ Testability seam: _DB_PATH global, override via set_db_path().
from __future__ import annotations
import json
import logging
import os
import sqlite3
import subprocess as _subprocess
@ -38,13 +39,17 @@ from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Generator
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent.parent
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
_MODELS_DIR: Path = _ROOT / "models"
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
_running_procs: dict[str, Any] = {}
router = APIRouter()
@ -60,6 +65,45 @@ def set_models_dir(path: Path) -> None:
global _MODELS_DIR
_MODELS_DIR = path
def set_config_dir(path: "Path | None") -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
# -- Config helpers ----------------------------------------------------
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_train_config() -> dict:
"""Read python_bin from label_tool.yaml.
Priority (highest to lowest):
1. label_tool.yaml train: python_bin
2. label_tool.yaml cforch: python_bin
3. Hardcoded default (classifiers conda env)
"""
_DEFAULT_PYTHON_BIN = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
f = _config_file()
train_cfg: dict = {}
cforch_cfg: dict = {}
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
train_cfg = raw.get("train", {}) or {}
cforch_cfg = raw.get("cforch", {}) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse train config %s: %s", f, exc)
return {
"python_bin": train_cfg.get(
"python_bin",
cforch_cfg.get("python_bin", _DEFAULT_PYTHON_BIN),
),
}
# -- Database helpers --------------------------------------------------
@ -182,7 +226,10 @@ def cancel_job(job_id: str) -> dict:
proc.terminate()
proc.wait(timeout=3)
except _subprocess.TimeoutExpired:
try:
proc.kill()
except OSError:
pass
return {"status": "cancelled"}
@ -198,7 +245,8 @@ def run_job(job_id: str) -> StreamingResponse:
job = _row_to_dict(row)
def generate():
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
cfg = _load_train_config()
python_bin = cfg["python_bin"]
config = json.loads(job["config_json"] or "{}")
model_key = job["model_key"]
epochs = config.get("epochs", 5)
@ -209,13 +257,14 @@ def run_job(job_id: str) -> StreamingResponse:
data_dir = _ROOT / "data"
for sf in config.get("score_files", []):
resolved = (data_dir / sf).resolve()
if str(resolved).startswith(str(data_dir.resolve())):
if resolved.is_relative_to(data_dir.resolve()):
cmd.extend(["--score", str(resolved)])
elif job["type"] == "llm-sft":
script = str(_ROOT / "scripts" / "finetune_sft.py")
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
else:
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job['type']}'})}\n\n"
job_type = job["type"]
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job_type}'})}\n\n"
return
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
@ -260,8 +309,11 @@ def run_job(job_id: str) -> StreamingResponse:
_running_procs.pop(job_id, None)
except Exception as exc:
err = str(exc)
finished_at = datetime.now(timezone.utc).isoformat()
with _db() as conn:
conn.execute("UPDATE jobs SET status='failed', error=? WHERE id=?", (err, job_id))
conn.execute(
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
(finished_at, err, job_id))
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream",
@ -282,6 +334,6 @@ def list_results() -> list[dict]:
try:
info = json.loads(info_path.read_text(encoding="utf-8"))
results.append(info)
except Exception:
pass
except Exception as exc:
logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
return results