fix(avocet): _MODELS_DIR overridable in tests; sanitize score paths against path traversal
This commit is contained in:
parent
ef8adfb035
commit
60fe1231ce
2 changed files with 23 additions and 11 deletions
19
app/api.py
19
app/api.py
|
|
@ -16,8 +16,9 @@ from fastapi import FastAPI, HTTPException, Query
|
|||
from pydantic import BaseModel
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
|
||||
_CONFIG_DIR: Path | None = None # None = use real path
|
||||
_DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
|
||||
_MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir()
|
||||
_CONFIG_DIR: Path | None = None # None = use real path
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
|
|
@ -26,6 +27,12 @@ def set_data_dir(path: Path) -> None:
|
|||
_DATA_DIR = path
|
||||
|
||||
|
||||
def set_models_dir(path: Path) -> None:
|
||||
"""Override models directory — used by tests."""
|
||||
global _MODELS_DIR
|
||||
_MODELS_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
"""Override config directory — used by tests."""
|
||||
global _CONFIG_DIR
|
||||
|
|
@ -347,7 +354,7 @@ def run_benchmark(include_slow: bool = False):
|
|||
@app.get("/api/finetune/status")
|
||||
def get_finetune_status():
|
||||
"""Scan models/ for training_info.json files. Returns [] if none exist."""
|
||||
models_dir = _ROOT / "models"
|
||||
models_dir = _MODELS_DIR
|
||||
if not models_dir.exists():
|
||||
return []
|
||||
results = []
|
||||
|
|
@ -377,8 +384,12 @@ def run_finetune_endpoint(
|
|||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
||||
cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
|
||||
data_root = _DATA_DIR.resolve()
|
||||
for score_file in score:
|
||||
cmd.extend(["--score", score_file])
|
||||
resolved = (_DATA_DIR / score_file).resolve()
|
||||
if not str(resolved).startswith(str(data_root)):
|
||||
raise HTTPException(400, f"Invalid score path: {score_file!r}")
|
||||
cmd.extend(["--score", str(resolved)])
|
||||
|
||||
def generate():
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -336,13 +336,13 @@ def test_finetune_status_returns_empty_when_no_models_dir(client):
|
|||
assert r.json() == []
|
||||
|
||||
|
||||
def test_finetune_status_returns_training_info(client):
|
||||
def test_finetune_status_returns_training_info(client, tmp_path):
|
||||
"""GET /api/finetune/status must return one entry per training_info.json found."""
|
||||
import json as _json
|
||||
from app import api as api_module
|
||||
|
||||
models_dir = api_module._ROOT / "models" / "avocet-deberta-small-test"
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
models_dir = tmp_path / "models" / "avocet-deberta-small"
|
||||
models_dir.mkdir(parents=True)
|
||||
info = {
|
||||
"name": "avocet-deberta-small",
|
||||
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
||||
|
|
@ -352,14 +352,14 @@ def test_finetune_status_returns_training_info(client):
|
|||
}
|
||||
(models_dir / "training_info.json").write_text(_json.dumps(info))
|
||||
|
||||
api_module.set_models_dir(tmp_path / "models")
|
||||
try:
|
||||
r = client.get("/api/finetune/status")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert any(d["name"] == "avocet-deberta-small" for d in data)
|
||||
finally:
|
||||
import shutil
|
||||
shutil.rmtree(models_dir)
|
||||
api_module.set_models_dir(api_module._ROOT / "models")
|
||||
|
||||
|
||||
def test_finetune_run_streams_sse_events(client):
|
||||
|
|
@ -427,5 +427,6 @@ def test_finetune_run_passes_score_files_to_subprocess(client):
|
|||
|
||||
assert "--score" in captured_cmd
|
||||
assert captured_cmd.count("--score") == 2
|
||||
assert "run1.jsonl" in captured_cmd
|
||||
assert "run2.jsonl" in captured_cmd
|
||||
# Paths are resolved to absolute — check filenames are present as substrings
|
||||
assert any("run1.jsonl" in arg for arg in captured_cmd)
|
||||
assert any("run2.jsonl" in arg for arg in captured_cmd)
|
||||
|
|
|
|||
Loading…
Reference in a new issue