121 lines
4.8 KiB
Python
121 lines
4.8 KiB
Python
import json
|
|
import pytest
|
|
from pathlib import Path
|
|
from unittest.mock import patch, MagicMock
|
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
from cryptography.hazmat.primitives import serialization
|
|
import jwt as pyjwt
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
|
|
@pytest.fixture()
def test_keys(tmp_path):
    """Create a throwaway RSA keypair for signing test licenses.

    The public half is also written to ``tmp_path / "test_public.pem"`` so it
    can be passed to code that loads the verification key from disk.

    Returns:
        A ``(private_pem, public_pem, public_path)`` tuple: the unencrypted
        PEM private key (TraditionalOpenSSL format), the PEM
        SubjectPublicKeyInfo public key, and the file it was written to.
    """
    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    pem = serialization.Encoding.PEM
    private_pem = key.private_bytes(
        encoding=pem,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem = key.public_key().public_bytes(
        encoding=pem,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    public_path = tmp_path / "test_public.pem"
    public_path.write_bytes(public_pem)
    return private_pem, public_pem, public_path
|
|
|
|
|
|
def _make_jwt(private_pem: bytes, tier: str = "paid",
              product: str = "peregrine",
              exp_delta_days: int = 30,
              machine: str = "test-machine") -> str:
    """Sign a test license token with RS256.

    Args:
        private_pem: PEM-encoded RSA private key used to sign the token.
        tier: Value of the ``tier`` claim.
        product: Value of the ``product`` claim.
        exp_delta_days: Offset in days from now for the ``exp`` claim; a
            negative value produces an already-expired token.
        machine: Value of the ``machine`` claim.

    Returns:
        The encoded token as returned by ``pyjwt.encode``.
    """
    issued_at = datetime.now(timezone.utc)
    expires_at = issued_at + timedelta(days=exp_delta_days)
    claims = dict(
        sub="CFG-PRNG-TEST-TEST-TEST",
        product=product,
        tier=tier,
        seats=1,
        machine=machine,
        iat=issued_at,
        exp=expires_at,
    )
    return pyjwt.encode(claims, private_pem, algorithm="RS256")
|
|
|
|
|
|
def _write_license(tmp_path, jwt_token: str, grace_until: str | None = None) -> Path:
|
|
data = {
|
|
"jwt": jwt_token,
|
|
"key_display": "CFG-PRNG-TEST-TEST-TEST",
|
|
"tier": "paid",
|
|
"valid_until": None,
|
|
"machine_id": "test-machine",
|
|
"last_refresh": datetime.now(timezone.utc).isoformat(),
|
|
"grace_until": grace_until,
|
|
}
|
|
p = tmp_path / "license.json"
|
|
p.write_text(json.dumps(data))
|
|
return p
|
|
|
|
|
|
class TestVerifyLocal:
    """Tests for ``scripts.license.verify_local``.

    ``verify_local`` is imported inside each test body rather than at module
    level, keeping the layout used throughout this file.
    """

    def test_valid_jwt_returns_tier(self, test_keys, tmp_path):
        """A correctly signed, unexpired token yields its tier."""
        private_pem, _, public_path = test_keys
        token = _make_jwt(private_pem)
        license_path = _write_license(tmp_path, token)
        from scripts.license import verify_local
        result = verify_local(license_path=license_path, public_key_path=public_path)
        assert result is not None
        assert result["tier"] == "paid"

    def test_missing_file_returns_none(self, tmp_path):
        """Neither the license file nor the public key file exists."""
        from scripts.license import verify_local
        result = verify_local(license_path=tmp_path / "missing.json",
                              public_key_path=tmp_path / "key.pem")
        assert result is None

    def test_missing_license_with_valid_key_returns_none(self, test_keys, tmp_path):
        """Only the license file is missing; the public key is real.

        Isolates the missing-license path from the missing-key path exercised
        by ``test_missing_file_returns_none``.
        """
        _, _, public_path = test_keys
        from scripts.license import verify_local
        result = verify_local(license_path=tmp_path / "missing.json",
                              public_key_path=public_path)
        assert result is None

    def test_foreign_signature_returns_none(self, test_keys, tmp_path):
        """A token signed by a key other than the trusted one is rejected.

        Every other positive case in this class signs with the matching key,
        so without this test an implementation that skipped signature
        verification would still pass.
        """
        _, _, public_path = test_keys
        foreign_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        foreign_pem = foreign_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
        token = _make_jwt(foreign_pem)
        license_path = _write_license(tmp_path, token)
        from scripts.license import verify_local
        result = verify_local(license_path=license_path, public_key_path=public_path)
        assert result is None

    def test_wrong_product_returns_none(self, test_keys, tmp_path):
        """A validly signed token issued for another product is rejected."""
        private_pem, _, public_path = test_keys
        token = _make_jwt(private_pem, product="falcon")
        license_path = _write_license(tmp_path, token)
        from scripts.license import verify_local
        result = verify_local(license_path=license_path, public_key_path=public_path)
        assert result is None

    def test_expired_within_grace_returns_tier(self, test_keys, tmp_path):
        """An expired token is still honoured while ``grace_until`` is in the future."""
        private_pem, _, public_path = test_keys
        token = _make_jwt(private_pem, exp_delta_days=-1)
        grace_until = (datetime.now(timezone.utc) + timedelta(days=3)).isoformat()
        license_path = _write_license(tmp_path, token, grace_until=grace_until)
        from scripts.license import verify_local
        result = verify_local(license_path=license_path, public_key_path=public_path)
        assert result is not None
        assert result["tier"] == "paid"
        assert result["in_grace"] is True

    def test_expired_past_grace_returns_none(self, test_keys, tmp_path):
        """An expired token whose ``grace_until`` has also passed is rejected."""
        private_pem, _, public_path = test_keys
        token = _make_jwt(private_pem, exp_delta_days=-10)
        grace_until = (datetime.now(timezone.utc) - timedelta(days=1)).isoformat()
        license_path = _write_license(tmp_path, token, grace_until=grace_until)
        from scripts.license import verify_local
        result = verify_local(license_path=license_path, public_key_path=public_path)
        assert result is None
|
|
|
|
|
|
class TestEffectiveTier:
    """Tests for ``scripts.license.effective_tier``."""

    def test_returns_free_when_no_license(self, tmp_path):
        """With no license file (and no key file) the tier falls back to free."""
        from scripts.license import effective_tier
        result = effective_tier(
            license_path=tmp_path / "missing.json",
            public_key_path=tmp_path / "key.pem",
        )
        assert result == "free"

    def test_returns_free_when_license_past_grace(self, test_keys, tmp_path):
        """A license file that exists but no longer verifies falls back to free.

        The token expired ten days ago and its grace period ended yesterday,
        so an existing-but-invalid license must not grant a paid tier.
        """
        private_pem, _, public_path = test_keys
        token = _make_jwt(private_pem, exp_delta_days=-10)
        grace_until = (datetime.now(timezone.utc) - timedelta(days=1)).isoformat()
        license_path = _write_license(tmp_path, token, grace_until=grace_until)
        from scripts.license import effective_tier
        result = effective_tier(license_path=license_path, public_key_path=public_path)
        assert result == "free"

    def test_returns_tier_from_valid_jwt(self, test_keys, tmp_path):
        """The tier comes from the signed token, not the license file's cached field."""
        private_pem, _, public_path = test_keys
        # _write_license always stores tier="paid"; signing "premium" into the
        # JWT makes the assertion below prove the token is the source of truth.
        token = _make_jwt(private_pem, tier="premium")
        license_path = _write_license(tmp_path, token)
        from scripts.license import effective_tier
        result = effective_tier(license_path=license_path, public_key_path=public_path)
        assert result == "premium"
|