fix(avocet): _MODELS_DIR overridable in tests; sanitize score paths against path traversal

This commit is contained in:
pyr0ball 2026-03-15 16:07:27 -07:00
parent ef8adfb035
commit 60fe1231ce
2 changed files with 23 additions and 11 deletions

View file

@@ -16,8 +16,9 @@ from fastapi import FastAPI, HTTPException, Query
 from pydantic import BaseModel

 _ROOT = Path(__file__).parent.parent
 _DATA_DIR: Path = _ROOT / "data"  # overridable in tests via set_data_dir()
+_MODELS_DIR: Path = _ROOT / "models"  # overridable in tests via set_models_dir()
 _CONFIG_DIR: Path | None = None  # None = use real path

 def set_data_dir(path: Path) -> None:
@@ -26,6 +27,12 @@ def set_data_dir(path: Path) -> None:
     _DATA_DIR = path
def set_models_dir(path: Path) -> None:
    """Redirect the models directory to *path* — test hook, mirrors set_data_dir()."""
    global _MODELS_DIR  # module-level default is _ROOT / "models"
    _MODELS_DIR = path
 def set_config_dir(path: Path | None) -> None:
     """Override config directory — used by tests."""
     global _CONFIG_DIR
@@ -347,7 +354,7 @@ def run_benchmark(include_slow: bool = False):
 @app.get("/api/finetune/status")
 def get_finetune_status():
     """Scan models/ for training_info.json files. Returns [] if none exist."""
-    models_dir = _ROOT / "models"
+    models_dir = _MODELS_DIR
     if not models_dir.exists():
         return []
     results = []
@@ -377,8 +384,12 @@ def run_finetune_endpoint(
     python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
     script = str(_ROOT / "scripts" / "finetune_classifier.py")
     cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
+    data_root = _DATA_DIR.resolve()
     for score_file in score:
-        cmd.extend(["--score", score_file])
+        resolved = (_DATA_DIR / score_file).resolve()
+        if not str(resolved).startswith(str(data_root)):
+            raise HTTPException(400, f"Invalid score path: {score_file!r}")
+        cmd.extend(["--score", str(resolved)])

     def generate():
         try:

View file

@@ -336,13 +336,13 @@ def test_finetune_status_returns_empty_when_no_models_dir(client):
     assert r.json() == []

-def test_finetune_status_returns_training_info(client):
+def test_finetune_status_returns_training_info(client, tmp_path):
     """GET /api/finetune/status must return one entry per training_info.json found."""
     import json as _json
     from app import api as api_module

-    models_dir = api_module._ROOT / "models" / "avocet-deberta-small-test"
-    models_dir.mkdir(parents=True, exist_ok=True)
+    models_dir = tmp_path / "models" / "avocet-deberta-small"
+    models_dir.mkdir(parents=True)
     info = {
         "name": "avocet-deberta-small",
         "base_model_id": "cross-encoder/nli-deberta-v3-small",
@@ -352,14 +352,14 @@ def test_finetune_status_returns_training_info(client):
     }
     (models_dir / "training_info.json").write_text(_json.dumps(info))

+    api_module.set_models_dir(tmp_path / "models")
     try:
         r = client.get("/api/finetune/status")
         assert r.status_code == 200
         data = r.json()
         assert any(d["name"] == "avocet-deberta-small" for d in data)
     finally:
-        import shutil
-        shutil.rmtree(models_dir)
+        api_module.set_models_dir(api_module._ROOT / "models")
 def test_finetune_run_streams_sse_events(client):
@@ -427,5 +427,6 @@ def test_finetune_run_passes_score_files_to_subprocess(client):
     assert "--score" in captured_cmd
     assert captured_cmd.count("--score") == 2
-    assert "run1.jsonl" in captured_cmd
-    assert "run2.jsonl" in captured_cmd
+    # Paths are resolved to absolute — check filenames are present as substrings
+    assert any("run1.jsonl" in arg for arg in captured_cmd)
+    assert any("run2.jsonl" in arg for arg in captured_cmd)