peregrine/tests/test_license.py
pyr0ball e87c707dd9
Some checks failed
CI / Backend (Python) (push) Failing after 30s
CI / Frontend (Vue) (push) Successful in 22s
CI / Backend (Python) (pull_request) Failing after 27s
CI / Frontend (Vue) (pull_request) Successful in 20s
chore(lint): ruff auto-fix unused imports in tests/
Removes unused imports flagged by ruff F401 across 47 test files.
Auto-fix only — imports verified unused by static analysis.
2026-05-20 23:07:52 -07:00

120 lines
4.8 KiB
Python

import json
import pytest
from pathlib import Path
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
import jwt as pyjwt
from datetime import datetime, timedelta, timezone
@pytest.fixture()
def test_keys(tmp_path):
    """Generate a throwaway RSA keypair for signing and verifying test JWTs.

    The public key is also written to ``test_public.pem`` under ``tmp_path``
    so it can be passed to code that expects a key file on disk.

    Returns:
        A ``(private_pem, public_pem, public_path)`` tuple: the PEM-encoded
        private key bytes, the PEM-encoded public key bytes, and the path of
        the public key file that was written.
    """
    private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    # Unencrypted PEM: the key only ever lives inside the test's tmp dir.
    private_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    public_path = tmp_path / "test_public.pem"
    public_path.write_bytes(public_pem)
    return private_pem, public_pem, public_path
def _make_jwt(
    private_pem: bytes,
    tier: str = "paid",
    product: str = "peregrine",
    exp_delta_days: int = 30,
    machine: str = "test-machine",
) -> str:
    """Build an RS256-signed license JWT for use in tests.

    Args:
        private_pem: PEM-encoded private key used to sign the token.
        tier: Value stored in the ``tier`` claim.
        product: Value stored in the ``product`` claim.
        exp_delta_days: Offset in days from now for the ``exp`` claim; a
            negative value produces a token that is already expired.
        machine: Value stored in the ``machine`` claim.

    Returns:
        The encoded token produced by ``pyjwt.encode``.
    """
    # Timezone-aware "now" so iat/exp are unambiguous UTC instants.
    now = datetime.now(timezone.utc)
    payload = {
        "sub": "CFG-PRNG-TEST-TEST-TEST",
        "product": product,
        "tier": tier,
        "seats": 1,
        "machine": machine,
        "iat": now,
        "exp": now + timedelta(days=exp_delta_days),
    }
    return pyjwt.encode(payload, private_pem, algorithm="RS256")
def _write_license(tmp_path, jwt_token: str, grace_until: str | None = None) -> Path:
data = {
"jwt": jwt_token,
"key_display": "CFG-PRNG-TEST-TEST-TEST",
"tier": "paid",
"valid_until": None,
"machine_id": "test-machine",
"last_refresh": datetime.now(timezone.utc).isoformat(),
"grace_until": grace_until,
}
p = tmp_path / "license.json"
p.write_text(json.dumps(data))
return p
class TestVerifyLocal:
    """Tests for ``scripts.license.verify_local``.

    ``verify_local`` is imported inside each test rather than at module
    level, so an import problem fails the individual test instead of the
    whole module's collection.
    """

    def test_valid_jwt_returns_tier(self, test_keys, tmp_path):
        """A correctly signed, unexpired token yields a result with its tier."""
        private_pem, _, public_path = test_keys
        token = _make_jwt(private_pem)
        license_path = _write_license(tmp_path, token)
        from scripts.license import verify_local
        result = verify_local(license_path=license_path, public_key_path=public_path)
        assert result is not None
        assert result["tier"] == "paid"

    def test_missing_file_returns_none(self, tmp_path):
        """Neither the license file nor the key file exists: expect ``None``."""
        from scripts.license import verify_local
        result = verify_local(
            license_path=tmp_path / "missing.json",
            public_key_path=tmp_path / "key.pem",
        )
        assert result is None

    def test_wrong_product_returns_none(self, test_keys, tmp_path):
        """A token whose ``product`` claim is not ``peregrine`` is rejected."""
        private_pem, _, public_path = test_keys
        token = _make_jwt(private_pem, product="falcon")
        license_path = _write_license(tmp_path, token)
        from scripts.license import verify_local
        result = verify_local(license_path=license_path, public_key_path=public_path)
        assert result is None

    def test_expired_within_grace_returns_tier(self, test_keys, tmp_path):
        """An expired token is still honored while ``grace_until`` is in the future."""
        private_pem, _, public_path = test_keys
        # Token expired yesterday; grace window stays open for three more days.
        token = _make_jwt(private_pem, exp_delta_days=-1)
        grace_until = (datetime.now(timezone.utc) + timedelta(days=3)).isoformat()
        license_path = _write_license(tmp_path, token, grace_until=grace_until)
        from scripts.license import verify_local
        result = verify_local(license_path=license_path, public_key_path=public_path)
        assert result is not None
        assert result["tier"] == "paid"
        assert result["in_grace"] is True

    def test_expired_past_grace_returns_none(self, test_keys, tmp_path):
        """An expired token whose grace window has also closed is rejected."""
        private_pem, _, public_path = test_keys
        # Token expired ten days ago; grace window closed yesterday.
        token = _make_jwt(private_pem, exp_delta_days=-10)
        grace_until = (datetime.now(timezone.utc) - timedelta(days=1)).isoformat()
        license_path = _write_license(tmp_path, token, grace_until=grace_until)
        from scripts.license import verify_local
        result = verify_local(license_path=license_path, public_key_path=public_path)
        assert result is None
class TestEffectiveTier:
    """Tests for ``scripts.license.effective_tier``.

    ``effective_tier`` is imported inside each test rather than at module
    level, so an import problem fails the individual test instead of the
    whole module's collection.
    """

    def test_returns_free_when_no_license(self, tmp_path):
        """With no license file on disk the expected tier is ``"free"``."""
        from scripts.license import effective_tier
        result = effective_tier(
            license_path=tmp_path / "missing.json",
            public_key_path=tmp_path / "key.pem",
        )
        assert result == "free"

    def test_returns_tier_from_valid_jwt(self, test_keys, tmp_path):
        """The tier is expected to come from the JWT claim.

        The license file written by ``_write_license`` always records
        ``"paid"``, while the token here carries ``"premium"``; the assertion
        expects the token's value.
        """
        private_pem, _, public_path = test_keys
        token = _make_jwt(private_pem, tier="premium")
        license_path = _write_license(tmp_path, token)
        from scripts.license import effective_tier
        result = effective_tier(license_path=license_path, public_key_path=public_path)
        assert result == "premium"