fix: path traversal guard, python_bin config, completed_at on Popen failure
This commit is contained in:
parent
766fbafa02
commit
32d3436bbd
1 changed files with 59 additions and 7 deletions
|
|
@ -29,6 +29,7 @@ Testability seam: _DB_PATH global, override via set_db_path().
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import subprocess as _subprocess
|
import subprocess as _subprocess
|
||||||
|
|
@ -38,13 +39,17 @@ from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Generator
|
from typing import Any, Generator
|
||||||
|
|
||||||
|
import yaml
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_ROOT = Path(__file__).parent.parent.parent
|
_ROOT = Path(__file__).parent.parent.parent
|
||||||
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
|
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
|
||||||
_MODELS_DIR: Path = _ROOT / "models"
|
_MODELS_DIR: Path = _ROOT / "models"
|
||||||
|
_CONFIG_DIR: Path | None = None # override in tests via set_config_dir()
|
||||||
_running_procs: dict[str, Any] = {}
|
_running_procs: dict[str, Any] = {}
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
@ -60,6 +65,45 @@ def set_models_dir(path: Path) -> None:
|
||||||
global _MODELS_DIR
|
global _MODELS_DIR
|
||||||
_MODELS_DIR = path
|
_MODELS_DIR = path
|
||||||
|
|
||||||
|
def set_config_dir(path: "Path | None") -> None:
|
||||||
|
global _CONFIG_DIR
|
||||||
|
_CONFIG_DIR = path
|
||||||
|
|
||||||
|
|
||||||
|
# -- Config helpers ----------------------------------------------------
|
||||||
|
|
||||||
|
def _config_file() -> Path:
|
||||||
|
if _CONFIG_DIR is not None:
|
||||||
|
return _CONFIG_DIR / "label_tool.yaml"
|
||||||
|
return _ROOT / "config" / "label_tool.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_train_config() -> dict:
|
||||||
|
"""Read python_bin from label_tool.yaml.
|
||||||
|
|
||||||
|
Priority (highest to lowest):
|
||||||
|
1. label_tool.yaml train: python_bin
|
||||||
|
2. label_tool.yaml cforch: python_bin
|
||||||
|
3. Hardcoded default (classifiers conda env)
|
||||||
|
"""
|
||||||
|
_DEFAULT_PYTHON_BIN = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||||
|
f = _config_file()
|
||||||
|
train_cfg: dict = {}
|
||||||
|
cforch_cfg: dict = {}
|
||||||
|
if f.exists():
|
||||||
|
try:
|
||||||
|
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||||
|
train_cfg = raw.get("train", {}) or {}
|
||||||
|
cforch_cfg = raw.get("cforch", {}) or {}
|
||||||
|
except yaml.YAMLError as exc:
|
||||||
|
logger.warning("Failed to parse train config %s: %s", f, exc)
|
||||||
|
return {
|
||||||
|
"python_bin": train_cfg.get(
|
||||||
|
"python_bin",
|
||||||
|
cforch_cfg.get("python_bin", _DEFAULT_PYTHON_BIN),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# -- Database helpers --------------------------------------------------
|
# -- Database helpers --------------------------------------------------
|
||||||
|
|
||||||
|
|
@ -182,7 +226,10 @@ def cancel_job(job_id: str) -> dict:
|
||||||
proc.terminate()
|
proc.terminate()
|
||||||
proc.wait(timeout=3)
|
proc.wait(timeout=3)
|
||||||
except _subprocess.TimeoutExpired:
|
except _subprocess.TimeoutExpired:
|
||||||
|
try:
|
||||||
proc.kill()
|
proc.kill()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
return {"status": "cancelled"}
|
return {"status": "cancelled"}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -198,7 +245,8 @@ def run_job(job_id: str) -> StreamingResponse:
|
||||||
job = _row_to_dict(row)
|
job = _row_to_dict(row)
|
||||||
|
|
||||||
def generate():
|
def generate():
|
||||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
cfg = _load_train_config()
|
||||||
|
python_bin = cfg["python_bin"]
|
||||||
config = json.loads(job["config_json"] or "{}")
|
config = json.loads(job["config_json"] or "{}")
|
||||||
model_key = job["model_key"]
|
model_key = job["model_key"]
|
||||||
epochs = config.get("epochs", 5)
|
epochs = config.get("epochs", 5)
|
||||||
|
|
@ -209,13 +257,14 @@ def run_job(job_id: str) -> StreamingResponse:
|
||||||
data_dir = _ROOT / "data"
|
data_dir = _ROOT / "data"
|
||||||
for sf in config.get("score_files", []):
|
for sf in config.get("score_files", []):
|
||||||
resolved = (data_dir / sf).resolve()
|
resolved = (data_dir / sf).resolve()
|
||||||
if str(resolved).startswith(str(data_dir.resolve())):
|
if resolved.is_relative_to(data_dir.resolve()):
|
||||||
cmd.extend(["--score", str(resolved)])
|
cmd.extend(["--score", str(resolved)])
|
||||||
elif job["type"] == "llm-sft":
|
elif job["type"] == "llm-sft":
|
||||||
script = str(_ROOT / "scripts" / "finetune_sft.py")
|
script = str(_ROOT / "scripts" / "finetune_sft.py")
|
||||||
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
||||||
else:
|
else:
|
||||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job['type']}'})}\n\n"
|
job_type = job["type"]
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job_type}'})}\n\n"
|
||||||
return
|
return
|
||||||
|
|
||||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
||||||
|
|
@ -260,8 +309,11 @@ def run_job(job_id: str) -> StreamingResponse:
|
||||||
_running_procs.pop(job_id, None)
|
_running_procs.pop(job_id, None)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
err = str(exc)
|
err = str(exc)
|
||||||
|
finished_at = datetime.now(timezone.utc).isoformat()
|
||||||
with _db() as conn:
|
with _db() as conn:
|
||||||
conn.execute("UPDATE jobs SET status='failed', error=? WHERE id=?", (err, job_id))
|
conn.execute(
|
||||||
|
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
|
||||||
|
(finished_at, err, job_id))
|
||||||
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
||||||
|
|
||||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||||
|
|
@ -282,6 +334,6 @@ def list_results() -> list[dict]:
|
||||||
try:
|
try:
|
||||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||||
results.append(info)
|
results.append(info)
|
||||||
except Exception:
|
except Exception as exc:
|
||||||
pass
|
logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
|
||||||
return results
|
return results
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue