fix: path traversal guard, python_bin config, completed_at on Popen failure

This commit is contained in:
pyr0ball 2026-05-01 23:24:00 -07:00
parent 766fbafa02
commit 32d3436bbd

View file

@ -29,6 +29,7 @@ Testability seam: _DB_PATH global, override via set_db_path().
from __future__ import annotations from __future__ import annotations
import json import json
import logging
import os import os
import sqlite3 import sqlite3
import subprocess as _subprocess import subprocess as _subprocess
@ -38,13 +39,17 @@ from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Generator from typing import Any, Generator
import yaml
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent.parent _ROOT = Path(__file__).parent.parent.parent
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db" _DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
_MODELS_DIR: Path = _ROOT / "models" _MODELS_DIR: Path = _ROOT / "models"
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
_running_procs: dict[str, Any] = {} _running_procs: dict[str, Any] = {}
router = APIRouter() router = APIRouter()
@ -60,6 +65,45 @@ def set_models_dir(path: Path) -> None:
global _MODELS_DIR global _MODELS_DIR
_MODELS_DIR = path _MODELS_DIR = path
def set_config_dir(path: "Path | None") -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
# -- Config helpers ----------------------------------------------------
def _config_file() -> Path:
if _CONFIG_DIR is not None:
return _CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
def _load_train_config() -> dict:
"""Read python_bin from label_tool.yaml.
Priority (highest to lowest):
1. label_tool.yaml train: python_bin
2. label_tool.yaml cforch: python_bin
3. Hardcoded default (classifiers conda env)
"""
_DEFAULT_PYTHON_BIN = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
f = _config_file()
train_cfg: dict = {}
cforch_cfg: dict = {}
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
train_cfg = raw.get("train", {}) or {}
cforch_cfg = raw.get("cforch", {}) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse train config %s: %s", f, exc)
return {
"python_bin": train_cfg.get(
"python_bin",
cforch_cfg.get("python_bin", _DEFAULT_PYTHON_BIN),
),
}
# -- Database helpers -------------------------------------------------- # -- Database helpers --------------------------------------------------
@ -182,7 +226,10 @@ def cancel_job(job_id: str) -> dict:
proc.terminate() proc.terminate()
proc.wait(timeout=3) proc.wait(timeout=3)
except _subprocess.TimeoutExpired: except _subprocess.TimeoutExpired:
try:
proc.kill() proc.kill()
except OSError:
pass
return {"status": "cancelled"} return {"status": "cancelled"}
@ -198,7 +245,8 @@ def run_job(job_id: str) -> StreamingResponse:
job = _row_to_dict(row) job = _row_to_dict(row)
def generate(): def generate():
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python" cfg = _load_train_config()
python_bin = cfg["python_bin"]
config = json.loads(job["config_json"] or "{}") config = json.loads(job["config_json"] or "{}")
model_key = job["model_key"] model_key = job["model_key"]
epochs = config.get("epochs", 5) epochs = config.get("epochs", 5)
@ -209,13 +257,14 @@ def run_job(job_id: str) -> StreamingResponse:
data_dir = _ROOT / "data" data_dir = _ROOT / "data"
for sf in config.get("score_files", []): for sf in config.get("score_files", []):
resolved = (data_dir / sf).resolve() resolved = (data_dir / sf).resolve()
if str(resolved).startswith(str(data_dir.resolve())): if resolved.is_relative_to(data_dir.resolve()):
cmd.extend(["--score", str(resolved)]) cmd.extend(["--score", str(resolved)])
elif job["type"] == "llm-sft": elif job["type"] == "llm-sft":
script = str(_ROOT / "scripts" / "finetune_sft.py") script = str(_ROOT / "scripts" / "finetune_sft.py")
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)] cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
else: else:
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job['type']}'})}\n\n" job_type = job["type"]
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job_type}'})}\n\n"
return return
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"} proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
@ -260,8 +309,11 @@ def run_job(job_id: str) -> StreamingResponse:
_running_procs.pop(job_id, None) _running_procs.pop(job_id, None)
except Exception as exc: except Exception as exc:
err = str(exc) err = str(exc)
finished_at = datetime.now(timezone.utc).isoformat()
with _db() as conn: with _db() as conn:
conn.execute("UPDATE jobs SET status='failed', error=? WHERE id=?", (err, job_id)) conn.execute(
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
(finished_at, err, job_id))
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n" yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream", return StreamingResponse(generate(), media_type="text/event-stream",
@ -282,6 +334,6 @@ def list_results() -> list[dict]:
try: try:
info = json.loads(info_path.read_text(encoding="utf-8")) info = json.loads(info_path.read_text(encoding="utf-8"))
results.append(info) results.append(info)
except Exception: except Exception as exc:
pass logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
return results return results