"""Tests for ``verify_local`` and ``effective_tier`` in ``scripts.license``."""
import json
from datetime import datetime, timedelta, timezone
from pathlib import Path

import jwt as pyjwt
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa


@pytest.fixture()
def test_keys(tmp_path):
    """Create a throwaway 2048-bit RSA keypair for signing test tokens.

    Returns:
        A ``(private_pem, public_pem, public_path)`` tuple; ``public_path``
        is the file under ``tmp_path`` that the public PEM was written to.
    """
    keypair = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    private_pem = keypair.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem = keypair.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    public_path = tmp_path / "test_public.pem"
    public_path.write_bytes(public_pem)
    return private_pem, public_pem, public_path


def _make_jwt(
    private_pem: bytes,
    tier: str = "paid",
    product: str = "peregrine",
    exp_delta_days: int = 30,
    machine: str = "test-machine",
) -> str:
    """Sign an RS256 license token with ``private_pem``.

    ``exp_delta_days`` is added to the current UTC time to form ``exp``, so
    a negative value produces an already-expired token.
    """
    issued_at = datetime.now(timezone.utc)
    expires_at = issued_at + timedelta(days=exp_delta_days)
    claims = {
        "sub": "CFG-PRNG-TEST-TEST-TEST",
        "product": product,
        "tier": tier,
        "seats": 1,
        "machine": machine,
        "iat": issued_at,
        "exp": expires_at,
    }
    return pyjwt.encode(claims, private_pem, algorithm="RS256")


def _write_license(tmp_path, jwt_token: str, grace_until: str | None = None) -> Path:
    """Write a ``license.json`` wrapping ``jwt_token`` and return its path.

    The file-level ``tier`` field is always ``"paid"``, independent of the
    tier claim carried inside the token.
    """
    record = {
        "jwt": jwt_token,
        "key_display": "CFG-PRNG-TEST-TEST-TEST",
        "tier": "paid",
        "valid_until": None,
        "machine_id": "test-machine",
        "last_refresh": datetime.now(timezone.utc).isoformat(),
        "grace_until": grace_until,
    }
    license_file = tmp_path / "license.json"
    license_file.write_text(json.dumps(record))
    return license_file


class TestVerifyLocal:
    """Behaviour of ``verify_local`` against on-disk license files."""

    def test_valid_jwt_returns_tier(self, test_keys, tmp_path):
        """An unexpired token for the default product yields its tier."""
        from scripts.license import verify_local

        signing_pem, _, pub_path = test_keys
        lic_path = _write_license(tmp_path, _make_jwt(signing_pem))

        outcome = verify_local(license_path=lic_path, public_key_path=pub_path)

        assert outcome is not None
        assert outcome["tier"] == "paid"

    def test_missing_file_returns_none(self, tmp_path):
        """Nonexistent license and key paths produce ``None``."""
        from scripts.license import verify_local

        absent_license = tmp_path / "missing.json"
        absent_key = tmp_path / "key.pem"

        outcome = verify_local(
            license_path=absent_license,
            public_key_path=absent_key,
        )

        assert outcome is None

    def test_wrong_product_returns_none(self, test_keys, tmp_path):
        """A token issued for product ``"falcon"`` is rejected."""
        from scripts.license import verify_local

        signing_pem, _, pub_path = test_keys
        foreign_token = _make_jwt(signing_pem, product="falcon")
        lic_path = _write_license(tmp_path, foreign_token)

        outcome = verify_local(license_path=lic_path, public_key_path=pub_path)

        assert outcome is None

    def test_expired_within_grace_returns_tier(self, test_keys, tmp_path):
        """A token expired one day ago still counts while grace runs."""
        from scripts.license import verify_local

        signing_pem, _, pub_path = test_keys
        stale_token = _make_jwt(signing_pem, exp_delta_days=-1)
        grace_deadline = datetime.now(timezone.utc) + timedelta(days=3)
        lic_path = _write_license(
            tmp_path,
            stale_token,
            grace_until=grace_deadline.isoformat(),
        )

        outcome = verify_local(license_path=lic_path, public_key_path=pub_path)

        assert outcome is not None
        assert outcome["tier"] == "paid"
        assert outcome["in_grace"] is True

    def test_expired_past_grace_returns_none(self, test_keys, tmp_path):
        """An expired token whose grace deadline has passed is rejected."""
        from scripts.license import verify_local

        signing_pem, _, pub_path = test_keys
        stale_token = _make_jwt(signing_pem, exp_delta_days=-10)
        grace_deadline = datetime.now(timezone.utc) - timedelta(days=1)
        lic_path = _write_license(
            tmp_path,
            stale_token,
            grace_until=grace_deadline.isoformat(),
        )

        outcome = verify_local(license_path=lic_path, public_key_path=pub_path)

        assert outcome is None


class TestEffectiveTier:
    """Behaviour of ``effective_tier``."""

    def test_returns_free_when_no_license(self, tmp_path):
        """With no license file present the tier is ``"free"``."""
        from scripts.license import effective_tier

        absent_license = tmp_path / "missing.json"
        absent_key = tmp_path / "key.pem"

        tier = effective_tier(
            license_path=absent_license,
            public_key_path=absent_key,
        )

        assert tier == "free"

    def test_returns_tier_from_valid_jwt(self, test_keys, tmp_path):
        """The tier claim inside the signed token is what gets reported."""
        from scripts.license import effective_tier

        signing_pem, _, pub_path = test_keys
        premium_token = _make_jwt(signing_pem, tier="premium")
        lic_path = _write_license(tmp_path, premium_token)

        tier = effective_tier(license_path=lic_path, public_key_path=pub_path)

        assert tier == "premium"