fix: path traversal guard, python_bin config, completed_at on Popen failure
This commit is contained in:
parent
766fbafa02
commit
32d3436bbd
1 changed files with 59 additions and 7 deletions
|
|
@ -29,6 +29,7 @@ Testability seam: _DB_PATH global, override via set_db_path().
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import subprocess as _subprocess
|
||||
|
|
@ -38,13 +39,17 @@ from datetime import datetime, timezone
|
|||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
|
||||
_MODELS_DIR: Path = _ROOT / "models"
|
||||
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||
_running_procs: dict[str, Any] = {}
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -60,6 +65,45 @@ def set_models_dir(path: Path) -> None:
|
|||
global _MODELS_DIR
|
||||
_MODELS_DIR = path
|
||||
|
||||
def set_config_dir(path: "Path | None") -> None:
    """Override the directory that holds ``label_tool.yaml``.

    Intended as a testability seam: pass a directory to redirect config
    lookups there, or ``None`` to restore the default location.
    """
    global _CONFIG_DIR
    _CONFIG_DIR = path
|
||||
|
||||
|
||||
# -- Config helpers ----------------------------------------------------
|
||||
|
||||
def _config_file() -> Path:
    """Return the path to ``label_tool.yaml``.

    Uses the directory set via ``set_config_dir()`` when one is active;
    otherwise falls back to ``<_ROOT>/config``.
    """
    base_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return base_dir / "label_tool.yaml"
|
||||
|
||||
|
||||
def _load_train_config() -> dict:
    """Resolve training settings (currently only ``python_bin``) from label_tool.yaml.

    Priority (highest to lowest):
      1. label_tool.yaml  train: python_bin
      2. label_tool.yaml  cforch: python_bin
      3. Hardcoded default (classifiers conda env)

    A missing file, an unreadable file, invalid YAML, a non-mapping document,
    or a non-mapping ``train``/``cforch`` section never raises: the affected
    level is skipped and resolution falls through to the next priority.
    A ``python_bin`` key that is present but null/empty is treated as unset.

    Returns:
        A dict with a single key, ``"python_bin"``, holding the interpreter
        path to use for training subprocesses.
    """
    default_python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
    cfg_path = _config_file()

    raw: Any = None
    try:
        raw = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        # No config file is a normal situation; use the defaults silently.
        pass
    except (OSError, UnicodeDecodeError) as exc:
        logger.warning("Failed to read train config %s: %s", cfg_path, exc)
    except yaml.YAMLError as exc:
        logger.warning("Failed to parse train config %s: %s", cfg_path, exc)

    if raw is not None and not isinstance(raw, dict):
        # e.g. the file contains a bare list or scalar instead of a mapping.
        logger.warning(
            "Ignoring train config %s: top-level YAML value is not a mapping", cfg_path
        )
        raw = None

    def _section(name: str) -> dict:
        # Only accept mapping sections; ``train: null`` or ``train: "x"`` count as empty.
        value = raw.get(name) if raw else None
        return value if isinstance(value, dict) else {}

    # ``or`` chaining (rather than dict.get defaults) so that a key which is
    # present but null/empty does not shadow the lower-priority sources.
    python_bin = (
        _section("train").get("python_bin")
        or _section("cforch").get("python_bin")
        or default_python_bin
    )
    return {"python_bin": python_bin}
|
||||
|
||||
|
||||
# -- Database helpers --------------------------------------------------
|
||||
|
||||
|
|
@ -182,7 +226,10 @@ def cancel_job(job_id: str) -> dict:
|
|||
proc.terminate()
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
try:
|
||||
proc.kill()
|
||||
except OSError:
|
||||
pass
|
||||
return {"status": "cancelled"}
|
||||
|
||||
|
||||
|
|
@ -198,7 +245,8 @@ def run_job(job_id: str) -> StreamingResponse:
|
|||
job = _row_to_dict(row)
|
||||
|
||||
def generate():
|
||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
cfg = _load_train_config()
|
||||
python_bin = cfg["python_bin"]
|
||||
config = json.loads(job["config_json"] or "{}")
|
||||
model_key = job["model_key"]
|
||||
epochs = config.get("epochs", 5)
|
||||
|
|
@ -209,13 +257,14 @@ def run_job(job_id: str) -> StreamingResponse:
|
|||
data_dir = _ROOT / "data"
|
||||
for sf in config.get("score_files", []):
|
||||
resolved = (data_dir / sf).resolve()
|
||||
if str(resolved).startswith(str(data_dir.resolve())):
|
||||
if resolved.is_relative_to(data_dir.resolve()):
|
||||
cmd.extend(["--score", str(resolved)])
|
||||
elif job["type"] == "llm-sft":
|
||||
script = str(_ROOT / "scripts" / "finetune_sft.py")
|
||||
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job['type']}'})}\n\n"
|
||||
job_type = job["type"]
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job_type}'})}\n\n"
|
||||
return
|
||||
|
||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
||||
|
|
@ -260,8 +309,11 @@ def run_job(job_id: str) -> StreamingResponse:
|
|||
_running_procs.pop(job_id, None)
|
||||
except Exception as exc:
|
||||
err = str(exc)
|
||||
finished_at = datetime.now(timezone.utc).isoformat()
|
||||
with _db() as conn:
|
||||
conn.execute("UPDATE jobs SET status='failed', error=? WHERE id=?", (err, job_id))
|
||||
conn.execute(
|
||||
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
|
||||
(finished_at, err, job_id))
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
|
|
@ -282,6 +334,6 @@ def list_results() -> list[dict]:
|
|||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
results.append(info)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
|
||||
return results
|
||||
|
|
|
|||
Loading…
Reference in a new issue