From 32d3436bbd6f2590e2693429eaaa46fcafb3a61a Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Fri, 1 May 2026 23:24:00 -0700 Subject: [PATCH] fix: path traversal guard, python_bin config, completed_at on Popen failure --- app/train/train.py | 66 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/app/train/train.py b/app/train/train.py index 9dfedd4..477be67 100644 --- a/app/train/train.py +++ b/app/train/train.py @@ -29,6 +29,7 @@ Testability seam: _DB_PATH global, override via set_db_path(). from __future__ import annotations import json +import logging import os import sqlite3 import subprocess as _subprocess @@ -38,13 +39,17 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any, Generator +import yaml from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel +logger = logging.getLogger(__name__) + _ROOT = Path(__file__).parent.parent.parent _DB_PATH: Path = _ROOT / "data" / "train_jobs.db" _MODELS_DIR: Path = _ROOT / "models" +_CONFIG_DIR: Path | None = None # override in tests via set_config_dir() _running_procs: dict[str, Any] = {} router = APIRouter() @@ -60,6 +65,45 @@ def set_models_dir(path: Path) -> None: global _MODELS_DIR _MODELS_DIR = path +def set_config_dir(path: "Path | None") -> None: + global _CONFIG_DIR + _CONFIG_DIR = path + + +# -- Config helpers ---------------------------------------------------- + +def _config_file() -> Path: + if _CONFIG_DIR is not None: + return _CONFIG_DIR / "label_tool.yaml" + return _ROOT / "config" / "label_tool.yaml" + + +def _load_train_config() -> dict: + """Read python_bin from label_tool.yaml. + + Priority (highest to lowest): + 1. label_tool.yaml train: python_bin + 2. label_tool.yaml cforch: python_bin + 3. Hardcoded default (classifiers conda env) + """ + _DEFAULT_PYTHON_BIN = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python" + f = _config_file() + train_cfg: dict = {} + cforch_cfg: dict = {} + if f.exists(): + try: + raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {} + train_cfg = raw.get("train", {}) or {} + cforch_cfg = raw.get("cforch", {}) or {} + except yaml.YAMLError as exc: + logger.warning("Failed to parse train config %s: %s", f, exc) + return { + "python_bin": train_cfg.get( + "python_bin", + cforch_cfg.get("python_bin", _DEFAULT_PYTHON_BIN), + ), + } + # -- Database helpers -------------------------------------------------- @@ -182,7 +226,10 @@ def cancel_job(job_id: str) -> dict: proc.terminate() proc.wait(timeout=3) except _subprocess.TimeoutExpired: - proc.kill() + try: + proc.kill() + except OSError: + pass return {"status": "cancelled"} @@ -198,7 +245,8 @@ def run_job(job_id: str) -> StreamingResponse: job = _row_to_dict(row) def generate(): - python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python" + cfg = _load_train_config() + python_bin = cfg["python_bin"] config = json.loads(job["config_json"] or "{}") model_key = job["model_key"] epochs = config.get("epochs", 5) @@ -209,13 +257,14 @@ def run_job(job_id: str) -> StreamingResponse: data_dir = _ROOT / "data" for sf in config.get("score_files", []): resolved = (data_dir / sf).resolve() - if str(resolved).startswith(str(data_dir.resolve())): + if resolved.is_relative_to(data_dir.resolve()): cmd.extend(["--score", str(resolved)]) elif job["type"] == "llm-sft": script = str(_ROOT / "scripts" / "finetune_sft.py") cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)] else: - yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job['type']}'})}\n\n" + job_type = job["type"] + yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job_type}'})}\n\n" return proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"} @@ -260,8 +309,11 @@ def run_job(job_id: str) -> StreamingResponse: _running_procs.pop(job_id, None) except Exception as exc: err = str(exc) + finished_at = datetime.now(timezone.utc).isoformat() with _db() as conn: - conn.execute("UPDATE jobs SET status='failed', error=? WHERE id=?", (err, job_id)) + conn.execute( + "UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?", + (finished_at, err, job_id)) yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n" return StreamingResponse(generate(), media_type="text/event-stream", @@ -282,6 +334,6 @@ def list_results() -> list[dict]: try: info = json.loads(info_path.read_text(encoding="utf-8")) results.append(info) - except Exception: - pass + except Exception as exc: + logger.warning("Failed to read training_info.json from %s: %s", info_path, exc) return results