# Source: peregrine/tests/test_license.py (121 lines, 4.8 KiB, Python)

import json
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
import jwt as pyjwt
from datetime import datetime, timedelta, timezone
@pytest.fixture()
def test_keys(tmp_path):
    """Create a throwaway RSA keypair for signing and verifying test JWTs.

    A new 2048-bit key is generated every time the fixture is requested, and
    the public half is also written to ``tmp_path / "test_public.pem"``.

    Returns:
        A ``(private_pem, public_pem, public_path)`` tuple: the PEM-encoded
        private key bytes, the PEM-encoded public key bytes, and the path of
        the file that holds the public PEM.
    """
    keypair = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    # Serialized without encryption so tests can sign tokens with no passphrase.
    private_pem = keypair.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )

    verifying_key = keypair.public_key()
    public_pem = verifying_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    public_path = tmp_path / "test_public.pem"
    public_path.write_bytes(public_pem)

    return private_pem, public_pem, public_path
def _make_jwt(
    private_pem: bytes,
    tier: str = "paid",
    product: str = "peregrine",
    exp_delta_days: int = 30,
    machine: str = "test-machine",
) -> str:
    """Sign a license JWT for tests using RS256.

    Args:
        private_pem: PEM-encoded RSA private key used to sign the token.
        tier: Value stored in the ``tier`` claim.
        product: Value stored in the ``product`` claim.
        exp_delta_days: Offset, in days from now, of the ``exp`` claim. A
            negative value produces a token whose expiry is already in the
            past.
        machine: Value stored in the ``machine`` claim.

    Returns:
        The encoded token returned by ``jwt.encode``.
    """
    issued_at = datetime.now(timezone.utc)
    expires_at = issued_at + timedelta(days=exp_delta_days)

    claims = {
        "sub": "CFG-PRNG-TEST-TEST-TEST",
        "product": product,
        "tier": tier,
        "seats": 1,
        "machine": machine,
        "iat": issued_at,
        "exp": expires_at,
    }
    return pyjwt.encode(claims, private_pem, algorithm="RS256")
def _write_license(tmp_path, jwt_token: str, grace_until: str | None = None) -> Path:
data = {
"jwt": jwt_token,
"key_display": "CFG-PRNG-TEST-TEST-TEST",
"tier": "paid",
"valid_until": None,
"machine_id": "test-machine",
"last_refresh": datetime.now(timezone.utc).isoformat(),
"grace_until": grace_until,
}
p = tmp_path / "license.json"
p.write_text(json.dumps(data))
return p
class TestVerifyLocal:
    """Tests for ``scripts.license.verify_local``."""

    def test_valid_jwt_returns_tier(self, test_keys, tmp_path):
        """A token signed with the matching key and not yet expired is accepted."""
        signing_pem, _, pubkey_file = test_keys
        license_file = _write_license(tmp_path, _make_jwt(signing_pem))

        from scripts.license import verify_local

        outcome = verify_local(
            license_path=license_file,
            public_key_path=pubkey_file,
        )

        assert outcome is not None
        assert outcome["tier"] == "paid"

    def test_missing_file_returns_none(self, tmp_path):
        """A license path that does not exist yields ``None``."""
        from scripts.license import verify_local

        absent_license = tmp_path / "missing.json"
        absent_key = tmp_path / "key.pem"

        outcome = verify_local(
            license_path=absent_license,
            public_key_path=absent_key,
        )

        assert outcome is None

    def test_wrong_product_returns_none(self, test_keys, tmp_path):
        """A token whose ``product`` claim is ``"falcon"`` is rejected."""
        signing_pem, _, pubkey_file = test_keys
        foreign_token = _make_jwt(signing_pem, product="falcon")
        license_file = _write_license(tmp_path, foreign_token)

        from scripts.license import verify_local

        outcome = verify_local(
            license_path=license_file,
            public_key_path=pubkey_file,
        )

        assert outcome is None

    def test_expired_within_grace_returns_tier(self, test_keys, tmp_path):
        """An expired token is still honoured while ``grace_until`` is in the future."""
        signing_pem, _, pubkey_file = test_keys
        stale_token = _make_jwt(signing_pem, exp_delta_days=-1)

        # Grace window closes three days from now.
        grace_deadline = datetime.now(timezone.utc) + timedelta(days=3)
        license_file = _write_license(
            tmp_path,
            stale_token,
            grace_until=grace_deadline.isoformat(),
        )

        from scripts.license import verify_local

        outcome = verify_local(
            license_path=license_file,
            public_key_path=pubkey_file,
        )

        assert outcome is not None
        assert outcome["tier"] == "paid"
        assert outcome["in_grace"] is True

    def test_expired_past_grace_returns_none(self, test_keys, tmp_path):
        """An expired token is rejected once ``grace_until`` has passed."""
        signing_pem, _, pubkey_file = test_keys
        stale_token = _make_jwt(signing_pem, exp_delta_days=-10)

        # Grace window already closed one day ago.
        grace_deadline = datetime.now(timezone.utc) - timedelta(days=1)
        license_file = _write_license(
            tmp_path,
            stale_token,
            grace_until=grace_deadline.isoformat(),
        )

        from scripts.license import verify_local

        outcome = verify_local(
            license_path=license_file,
            public_key_path=pubkey_file,
        )

        assert outcome is None
class TestEffectiveTier:
    """Tests for ``scripts.license.effective_tier``."""

    def test_returns_free_when_no_license(self, tmp_path):
        """With no license file on disk the effective tier is ``"free"``."""
        from scripts.license import effective_tier

        absent_license = tmp_path / "missing.json"
        absent_key = tmp_path / "key.pem"

        tier_name = effective_tier(
            license_path=absent_license,
            public_key_path=absent_key,
        )

        assert tier_name == "free"

    def test_returns_tier_from_valid_jwt(self, test_keys, tmp_path):
        """The reported tier is the one carried by the token's ``tier`` claim."""
        signing_pem, _, pubkey_file = test_keys
        premium_token = _make_jwt(signing_pem, tier="premium")
        license_file = _write_license(tmp_path, premium_token)

        from scripts.license import effective_tier

        tier_name = effective_tier(
            license_path=license_file,
            public_key_path=pubkey_file,
        )

        assert tier_name == "premium"