"""Tests for ssh_targets service — CRUD, validation, serialization.""" from __future__ import annotations import stat import sqlite3 from pathlib import Path import pytest def _make_db(tmp_path: Path) -> Path: """Create a minimal DB with the ssh_targets table via ensure_schema.""" from app.glean.pipeline import ensure_schema db = tmp_path / "test.db" ensure_schema(db) return db def _make_key(tmp_path: Path, mode: int = 0o600) -> Path: """Write a fake SSH private key file with the given permission mode.""" key = tmp_path / "id_ed25519" key.write_text("-----BEGIN OPENSSH PRIVATE KEY-----\nfake\n-----END OPENSSH PRIVATE KEY-----\n") key.chmod(mode) return key # --------------------------------------------------------------------------- # Schema # --------------------------------------------------------------------------- class TestSchema: def test_ssh_targets_table_exists(self, tmp_path): db = _make_db(tmp_path) conn = sqlite3.connect(str(db)) tables = {r[0] for r in conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()} assert "ssh_targets" in tables conn.close() def test_ssh_targets_columns(self, tmp_path): db = _make_db(tmp_path) conn = sqlite3.connect(str(db)) cols = {r[1] for r in conn.execute("PRAGMA table_info(ssh_targets)").fetchall()} assert cols >= {"id", "label", "host", "port", "user", "key_path", "last_tested", "last_ok", "last_error", "created_at", "updated_at"} conn.close() # --------------------------------------------------------------------------- # CRUD # --------------------------------------------------------------------------- class TestCrud: def test_create_and_list(self, tmp_path): from app.services.ssh_targets import create_target, list_targets db = _make_db(tmp_path) key = _make_key(tmp_path) t = create_target(db, label="server-01", host="10.0.0.1", port=22, user="alan", key_path=str(key)) assert t.label == "server-01" assert t.host == "10.0.0.1" assert t.port == 22 assert t.user == "alan" targets = list_targets(db) assert len(targets) == 1 assert targets[0].id == t.id def test_create_resolves_tilde(self, tmp_path): from app.services.ssh_targets import create_target from unittest.mock import patch db = _make_db(tmp_path) key = _make_key(tmp_path) with patch("pathlib.Path.expanduser", return_value=key): t = create_target(db, label="x", host="h", port=22, user="u", key_path="~/id_ed25519") assert "~" not in t.key_path def test_get_returns_none_for_missing(self, tmp_path): from app.services.ssh_targets import get_target db = _make_db(tmp_path) assert get_target(db, "nonexistent-id") is None def test_update_partial(self, tmp_path): from app.services.ssh_targets import create_target, update_target db = _make_db(tmp_path) key = _make_key(tmp_path) t = create_target(db, label="old-label", host="10.0.0.1", port=22, user="alan", key_path=str(key)) updated = update_target(db, t.id, label="new-label") assert updated is not None assert updated.label == "new-label" assert updated.host == "10.0.0.1" # unchanged def test_update_missing_target_returns_none(self, tmp_path): from app.services.ssh_targets import update_target db = _make_db(tmp_path) assert update_target(db, "no-such-id", label="x") is None def test_delete_returns_true_on_success(self, tmp_path): from app.services.ssh_targets import create_target, delete_target, list_targets db = _make_db(tmp_path) key = _make_key(tmp_path) t = create_target(db, label="x", host="h", port=22, user="u", key_path=str(key)) assert delete_target(db, t.id) is True assert list_targets(db) == [] def test_delete_returns_false_for_missing(self, tmp_path): from app.services.ssh_targets import delete_target db = _make_db(tmp_path) assert delete_target(db, "no-such-id") is False def test_list_sorted_by_label(self, tmp_path): from app.services.ssh_targets import create_target, list_targets db = _make_db(tmp_path) key = _make_key(tmp_path) create_target(db, label="zebra", host="h", port=22, user="u", key_path=str(key)) create_target(db, label="alpha", host="h", port=22, user="u", key_path=str(key)) labels = [t.label for t in list_targets(db)] assert labels == ["alpha", "zebra"] # --------------------------------------------------------------------------- # Validation # --------------------------------------------------------------------------- class TestValidation: def test_create_raises_on_missing_key_file(self, tmp_path): from app.services.ssh_targets import create_target db = _make_db(tmp_path) with pytest.raises(ValueError, match="not found"): create_target(db, label="x", host="h", port=22, user="u", key_path="/nonexistent/key") def test_create_raises_on_directory_as_key(self, tmp_path): from app.services.ssh_targets import create_target db = _make_db(tmp_path) with pytest.raises(ValueError, match="not a file"): create_target(db, label="x", host="h", port=22, user="u", key_path=str(tmp_path)) def test_update_raises_on_bad_key_path(self, tmp_path): from app.services.ssh_targets import create_target, update_target db = _make_db(tmp_path) key = _make_key(tmp_path) t = create_target(db, label="x", host="h", port=22, user="u", key_path=str(key)) with pytest.raises(ValueError): update_target(db, t.id, key_path="/does/not/exist") # --------------------------------------------------------------------------- # Key warning # --------------------------------------------------------------------------- class TestKeyWarning: def test_no_warning_for_600(self, tmp_path): from app.services.ssh_targets import key_path_warning key = _make_key(tmp_path, mode=0o600) assert key_path_warning(str(key)) is None def test_warning_for_644(self, tmp_path): from app.services.ssh_targets import key_path_warning key = _make_key(tmp_path, mode=0o644) warning = key_path_warning(str(key)) assert warning is not None assert "chmod 600" in warning def test_no_warning_for_nonexistent_file(self, tmp_path): from app.services.ssh_targets import key_path_warning # Should not raise — just return None result = key_path_warning("/nonexistent/path") assert result is None # --------------------------------------------------------------------------- # Serialization # --------------------------------------------------------------------------- class TestTargetToDict: def test_basic_fields_present(self, tmp_path): from app.services.ssh_targets import create_target, target_to_dict db = _make_db(tmp_path) key = _make_key(tmp_path) t = create_target(db, label="server", host="10.0.0.1", port=2222, user="admin", key_path=str(key)) d = target_to_dict(t) assert d["label"] == "server" assert d["host"] == "10.0.0.1" assert d["port"] == 2222 assert d["user"] == "admin" assert "key_path" in d assert "key_warning" not in d # not included by default def test_key_contents_never_in_dict(self, tmp_path): from app.services.ssh_targets import create_target, target_to_dict db = _make_db(tmp_path) key = _make_key(tmp_path) t = create_target(db, label="x", host="h", port=22, user="u", key_path=str(key)) d = target_to_dict(t, include_warning=True) for v in d.values(): if isinstance(v, str): assert "BEGIN" not in v, "Key contents must never be included in serialized output" def test_include_warning_adds_field(self, tmp_path): from app.services.ssh_targets import create_target, target_to_dict db = _make_db(tmp_path) key = _make_key(tmp_path, mode=0o644) t = create_target(db, label="x", host="h", port=22, user="u", key_path=str(key)) d = target_to_dict(t, include_warning=True) assert "key_warning" in d assert d["key_warning"] is not None def test_last_ok_is_none_before_test(self, tmp_path): from app.services.ssh_targets import create_target, target_to_dict db = _make_db(tmp_path) key = _make_key(tmp_path) t = create_target(db, label="x", host="h", port=22, user="u", key_path=str(key)) d = target_to_dict(t) assert d["last_ok"] is None assert d["last_tested"] is None # --------------------------------------------------------------------------- # test_connection (paramiko not available path) # --------------------------------------------------------------------------- class TestConnectionNoParamiko: def test_returns_error_when_paramiko_missing(self, tmp_path): from app.services.ssh_targets import create_target, test_connection import sys db = _make_db(tmp_path) key = _make_key(tmp_path) t = create_target(db, label="x", host="127.0.0.1", port=22, user="u", key_path=str(key)) # Temporarily hide paramiko from the import system original = sys.modules.get("paramiko") sys.modules["paramiko"] = None # type: ignore[assignment] try: result = test_connection(db, t.id) finally: if original is None: del sys.modules["paramiko"] else: sys.modules["paramiko"] = original assert result["ok"] is False assert "paramiko" in result["error"].lower() def test_raises_key_error_for_missing_target(self, tmp_path): from app.services.ssh_targets import test_connection db = _make_db(tmp_path) with pytest.raises(KeyError): test_connection(db, "no-such-id")