Compare commits
3 commits
d3ab3fa460
...
4331943cb4
| Author | SHA1 | Date | |
|---|---|---|---|
| 4331943cb4 | |||
| 4903f30c82 | |||
| b7ef804cf7 |
8 changed files with 560 additions and 2 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -29,3 +29,4 @@ scrapers/.debug/
|
|||
scrapers/raw_scrapes/
|
||||
|
||||
compose.override.yml
|
||||
config/license.json
|
||||
|
|
|
|||
|
|
@ -61,6 +61,13 @@ def _startup() -> None:
|
|||
|
||||
_startup()
|
||||
|
||||
# Silent license refresh on startup — no-op if unreachable
|
||||
try:
|
||||
from scripts.license import refresh_if_needed as _refresh_license
|
||||
_refresh_license()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── First-run wizard gate ───────────────────────────────────────────────────────
|
||||
from scripts.user_profile import UserProfile as _UserProfile
|
||||
_USER_YAML = Path(__file__).parent.parent / "config" / "user.yaml"
|
||||
|
|
|
|||
|
|
@ -89,12 +89,12 @@ _show_dev_tab = _dev_mode or bool(_u_for_dev.get("dev_tier_override"))
|
|||
_tab_names = [
|
||||
"👤 My Profile", "🔎 Search", "🤖 LLM Backends", "📚 Notion",
|
||||
"🔌 Services", "📝 Resume Profile", "📧 Email", "🏷️ Skills",
|
||||
"🔗 Integrations", "🎯 Fine-Tune"
|
||||
"🔗 Integrations", "🎯 Fine-Tune", "🔑 License"
|
||||
]
|
||||
if _show_dev_tab:
|
||||
_tab_names.append("🛠️ Developer")
|
||||
_all_tabs = st.tabs(_tab_names)
|
||||
tab_profile, tab_search, tab_llm, tab_notion, tab_services, tab_resume, tab_email, tab_skills, tab_integrations, tab_finetune = _all_tabs[:10]
|
||||
tab_profile, tab_search, tab_llm, tab_notion, tab_services, tab_resume, tab_email, tab_skills, tab_integrations, tab_finetune, tab_license = _all_tabs[:11]
|
||||
|
||||
with tab_profile:
|
||||
from scripts.user_profile import UserProfile as _UP, _DEFAULTS as _UP_DEFAULTS
|
||||
|
|
@ -1129,6 +1129,53 @@ with tab_finetune:
|
|||
if col_refresh.button("🔄 Check model status", key="ft_refresh3"):
|
||||
st.rerun()
|
||||
|
||||
# ── License tab ───────────────────────────────────────────────────────────────
|
||||
with tab_license:
|
||||
st.subheader("🔑 License")
|
||||
|
||||
from scripts.license import (
|
||||
verify_local as _verify_local,
|
||||
activate as _activate,
|
||||
deactivate as _deactivate,
|
||||
_DEFAULT_LICENSE_PATH,
|
||||
_DEFAULT_PUBLIC_KEY_PATH,
|
||||
)
|
||||
|
||||
_lic = _verify_local()
|
||||
|
||||
if _lic:
|
||||
_grace_note = " _(grace period active)_" if _lic.get("in_grace") else ""
|
||||
st.success(f"**{_lic['tier'].title()} tier** active{_grace_note}")
|
||||
try:
|
||||
import json as _json
|
||||
_key_display = _json.loads(_DEFAULT_LICENSE_PATH.read_text()).get("key_display", "—")
|
||||
except Exception:
|
||||
_key_display = "—"
|
||||
st.caption(f"Key: `{_key_display}`")
|
||||
if _lic.get("notice"):
|
||||
st.info(_lic["notice"])
|
||||
if st.button("Deactivate this machine", type="secondary", key="lic_deactivate"):
|
||||
_deactivate()
|
||||
st.success("Deactivated. Restart the app to apply.")
|
||||
st.rerun()
|
||||
else:
|
||||
st.info("No active license — running on **free tier**.")
|
||||
st.caption("Enter a license key to unlock paid features.")
|
||||
_key_input = st.text_input(
|
||||
"License key",
|
||||
placeholder="CFG-PRNG-XXXX-XXXX-XXXX",
|
||||
label_visibility="collapsed",
|
||||
key="lic_key_input",
|
||||
)
|
||||
if st.button("Activate", disabled=not (_key_input or "").strip(), key="lic_activate"):
|
||||
with st.spinner("Activating…"):
|
||||
try:
|
||||
result = _activate(_key_input.strip())
|
||||
st.success(f"Activated! Tier: **{result['tier']}**")
|
||||
st.rerun()
|
||||
except Exception as _e:
|
||||
st.error(f"Activation failed: {_e}")
|
||||
|
||||
# ── Developer tab ─────────────────────────────────────────────────────────────
|
||||
if _show_dev_tab:
|
||||
with _all_tabs[-1]:
|
||||
|
|
|
|||
|
|
@ -65,3 +65,32 @@ def tier_label(feature: str) -> str:
|
|||
if required is None:
|
||||
return ""
|
||||
return "🔒 Paid" if required == "paid" else "⭐ Premium"
|
||||
|
||||
|
||||
def effective_tier(
|
||||
profile=None,
|
||||
license_path=None,
|
||||
public_key_path=None,
|
||||
) -> str:
|
||||
"""Return the effective tier for this installation.
|
||||
|
||||
Priority:
|
||||
1. profile.dev_tier_override (developer mode override)
|
||||
2. License JWT verification (offline RS256 check)
|
||||
3. "free" (fallback)
|
||||
|
||||
license_path and public_key_path default to production paths when None.
|
||||
Pass explicit paths in tests to avoid touching real files.
|
||||
"""
|
||||
if profile and getattr(profile, "dev_tier_override", None):
|
||||
return profile.dev_tier_override
|
||||
|
||||
from scripts.license import effective_tier as _license_tier
|
||||
from pathlib import Path as _Path
|
||||
|
||||
kwargs = {}
|
||||
if license_path is not None:
|
||||
kwargs["license_path"] = _Path(license_path)
|
||||
if public_key_path is not None:
|
||||
kwargs["public_key_path"] = _Path(public_key_path)
|
||||
return _license_tier(**kwargs)
|
||||
|
|
|
|||
275
scripts/license.py
Normal file
275
scripts/license.py
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
"""
|
||||
CircuitForge license client for Peregrine.
|
||||
|
||||
Activates against the license server, caches a signed JWT locally,
|
||||
and verifies tier offline using the embedded RS256 public key.
|
||||
|
||||
All functions accept override paths for testing; production code uses
|
||||
the module-level defaults.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
_HERE = Path(__file__).parent
|
||||
_DEFAULT_LICENSE_PATH = _HERE.parent / "config" / "license.json"
|
||||
_DEFAULT_PUBLIC_KEY_PATH = _HERE / "license_public_key.pem"
|
||||
_LICENSE_SERVER = "https://license.circuitforge.tech"
|
||||
_PRODUCT = "peregrine"
|
||||
_REFRESH_THRESHOLD_DAYS = 5
|
||||
_GRACE_PERIOD_DAYS = 7
|
||||
|
||||
|
||||
# ── Machine fingerprint ────────────────────────────────────────────────────────
|
||||
|
||||
def _machine_id() -> str:
|
||||
raw = f"{socket.gethostname()}-{uuid.getnode()}"
|
||||
return hashlib.sha256(raw.encode()).hexdigest()[:32]
|
||||
|
||||
|
||||
# ── License file helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _read_license(license_path: Path) -> dict | None:
|
||||
try:
|
||||
return json.loads(license_path.read_text())
|
||||
except (FileNotFoundError, json.JSONDecodeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def _write_license(data: dict, license_path: Path) -> None:
|
||||
license_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
license_path.write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
# ── Core verify ───────────────────────────────────────────────────────────────
|
||||
|
||||
def verify_local(
|
||||
license_path: Path = _DEFAULT_LICENSE_PATH,
|
||||
public_key_path: Path = _DEFAULT_PUBLIC_KEY_PATH,
|
||||
) -> dict | None:
|
||||
"""Verify the cached JWT offline. Returns payload dict or None (= free tier).
|
||||
|
||||
Returned dict has keys: tier, in_grace (bool), sub, product, notice (optional).
|
||||
"""
|
||||
stored = _read_license(license_path)
|
||||
if not stored or not stored.get("jwt"):
|
||||
return None
|
||||
|
||||
if not public_key_path.exists():
|
||||
return None
|
||||
|
||||
public_key = public_key_path.read_bytes()
|
||||
|
||||
try:
|
||||
payload = pyjwt.decode(stored["jwt"], public_key, algorithms=["RS256"])
|
||||
if payload.get("product") != _PRODUCT:
|
||||
return None
|
||||
return {**payload, "in_grace": False}
|
||||
|
||||
except pyjwt.exceptions.ExpiredSignatureError:
|
||||
# JWT expired — check local grace period before requiring a refresh
|
||||
grace_until_str = stored.get("grace_until")
|
||||
if not grace_until_str:
|
||||
return None
|
||||
try:
|
||||
grace_until = datetime.fromisoformat(grace_until_str)
|
||||
if grace_until.tzinfo is None:
|
||||
grace_until = grace_until.replace(tzinfo=timezone.utc)
|
||||
except ValueError:
|
||||
return None
|
||||
if datetime.now(timezone.utc) > grace_until:
|
||||
return None
|
||||
# Decode without expiry check to recover the payload
|
||||
try:
|
||||
payload = pyjwt.decode(
|
||||
stored["jwt"], public_key,
|
||||
algorithms=["RS256"],
|
||||
options={"verify_exp": False},
|
||||
)
|
||||
if payload.get("product") != _PRODUCT:
|
||||
return None
|
||||
return {**payload, "in_grace": True}
|
||||
except pyjwt.exceptions.PyJWTError:
|
||||
return None
|
||||
|
||||
except pyjwt.exceptions.PyJWTError:
|
||||
return None
|
||||
|
||||
|
||||
def effective_tier(
|
||||
license_path: Path = _DEFAULT_LICENSE_PATH,
|
||||
public_key_path: Path = _DEFAULT_PUBLIC_KEY_PATH,
|
||||
) -> str:
|
||||
"""Return the effective tier string. Falls back to 'free' on any problem."""
|
||||
result = verify_local(license_path=license_path, public_key_path=public_key_path)
|
||||
if result is None:
|
||||
return "free"
|
||||
return result.get("tier", "free")
|
||||
|
||||
|
||||
# ── Network operations (all fire-and-forget or explicit) ──────────────────────
|
||||
|
||||
def activate(
|
||||
key: str,
|
||||
license_path: Path = _DEFAULT_LICENSE_PATH,
|
||||
public_key_path: Path = _DEFAULT_PUBLIC_KEY_PATH,
|
||||
app_version: str | None = None,
|
||||
) -> dict:
|
||||
"""Activate a license key. Returns response dict. Raises on failure."""
|
||||
import httpx
|
||||
mid = _machine_id()
|
||||
resp = httpx.post(
|
||||
f"{_LICENSE_SERVER}/activate",
|
||||
json={
|
||||
"key": key,
|
||||
"machine_id": mid,
|
||||
"product": _PRODUCT,
|
||||
"app_version": app_version,
|
||||
"platform": _detect_platform(),
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
stored = {
|
||||
"jwt": data["jwt"],
|
||||
"key_display": key,
|
||||
"tier": data["tier"],
|
||||
"valid_until": data.get("valid_until"),
|
||||
"machine_id": mid,
|
||||
"last_refresh": datetime.now(timezone.utc).isoformat(),
|
||||
"grace_until": None,
|
||||
}
|
||||
_write_license(stored, license_path)
|
||||
return data
|
||||
|
||||
|
||||
def deactivate(
|
||||
license_path: Path = _DEFAULT_LICENSE_PATH,
|
||||
) -> None:
|
||||
"""Deactivate this machine. Deletes license.json."""
|
||||
import httpx
|
||||
stored = _read_license(license_path)
|
||||
if not stored:
|
||||
return
|
||||
try:
|
||||
httpx.post(
|
||||
f"{_LICENSE_SERVER}/deactivate",
|
||||
json={"jwt": stored["jwt"], "machine_id": stored.get("machine_id", _machine_id())},
|
||||
timeout=10,
|
||||
)
|
||||
except Exception:
|
||||
pass # best-effort
|
||||
license_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def refresh_if_needed(
|
||||
license_path: Path = _DEFAULT_LICENSE_PATH,
|
||||
public_key_path: Path = _DEFAULT_PUBLIC_KEY_PATH,
|
||||
) -> None:
|
||||
"""Silently refresh JWT if it expires within threshold. No-op on network failure."""
|
||||
stored = _read_license(license_path)
|
||||
if not stored or not stored.get("jwt"):
|
||||
return
|
||||
try:
|
||||
payload = pyjwt.decode(
|
||||
stored["jwt"], public_key_path.read_bytes(), algorithms=["RS256"]
|
||||
)
|
||||
exp = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||
if exp - datetime.now(timezone.utc) > timedelta(days=_REFRESH_THRESHOLD_DAYS):
|
||||
return
|
||||
except pyjwt.exceptions.ExpiredSignatureError:
|
||||
# Already expired — try to refresh anyway, set grace if unreachable
|
||||
pass
|
||||
except Exception:
|
||||
return
|
||||
|
||||
try:
|
||||
import httpx
|
||||
resp = httpx.post(
|
||||
f"{_LICENSE_SERVER}/refresh",
|
||||
json={"jwt": stored["jwt"], "machine_id": stored.get("machine_id", _machine_id())},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
stored["jwt"] = data["jwt"]
|
||||
stored["tier"] = data["tier"]
|
||||
stored["last_refresh"] = datetime.now(timezone.utc).isoformat()
|
||||
stored["grace_until"] = None
|
||||
_write_license(stored, license_path)
|
||||
except Exception:
|
||||
# Server unreachable — set grace period if not already set
|
||||
if not stored.get("grace_until"):
|
||||
grace = datetime.now(timezone.utc) + timedelta(days=_GRACE_PERIOD_DAYS)
|
||||
stored["grace_until"] = grace.isoformat()
|
||||
_write_license(stored, license_path)
|
||||
|
||||
|
||||
def report_usage(
|
||||
event_type: str,
|
||||
metadata: dict | None = None,
|
||||
license_path: Path = _DEFAULT_LICENSE_PATH,
|
||||
) -> None:
|
||||
"""Fire-and-forget usage telemetry. Never blocks, never raises."""
|
||||
stored = _read_license(license_path)
|
||||
if not stored or not stored.get("jwt"):
|
||||
return
|
||||
|
||||
def _send():
|
||||
try:
|
||||
import httpx
|
||||
httpx.post(
|
||||
f"{_LICENSE_SERVER}/usage",
|
||||
json={"event_type": event_type, "product": _PRODUCT, "metadata": metadata or {}},
|
||||
headers={"Authorization": f"Bearer {stored['jwt']}"},
|
||||
timeout=5,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
threading.Thread(target=_send, daemon=True).start()
|
||||
|
||||
|
||||
def report_flag(
|
||||
flag_type: str,
|
||||
details: dict | None = None,
|
||||
license_path: Path = _DEFAULT_LICENSE_PATH,
|
||||
) -> None:
|
||||
"""Fire-and-forget violation report. Never blocks, never raises."""
|
||||
stored = _read_license(license_path)
|
||||
if not stored or not stored.get("jwt"):
|
||||
return
|
||||
|
||||
def _send():
|
||||
try:
|
||||
import httpx
|
||||
httpx.post(
|
||||
f"{_LICENSE_SERVER}/flag",
|
||||
json={"flag_type": flag_type, "product": _PRODUCT, "details": details or {}},
|
||||
headers={"Authorization": f"Bearer {stored['jwt']}"},
|
||||
timeout=5,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
threading.Thread(target=_send, daemon=True).start()
|
||||
|
||||
|
||||
def _detect_platform() -> str:
|
||||
import sys
|
||||
if sys.platform.startswith("linux"):
|
||||
return "linux"
|
||||
if sys.platform == "darwin":
|
||||
return "macos"
|
||||
if sys.platform == "win32":
|
||||
return "windows"
|
||||
return "unknown"
|
||||
9
scripts/license_public_key.pem
Normal file
9
scripts/license_public_key.pem
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAr9kLOyfJbm1QMFGdsC8b
|
||||
LR9xm4bCZ9L63o8doejfMHNliQrUxmmKPKYF4o3dE73Y9og7MrmQRN1pvFgvcVAj
|
||||
o7GB6os5hSf8DDLYSFa2uGwoWOTs9uhDHKcB32T7nI3PCq0hqIoLfwfc9noi+MWh
|
||||
UP8APzgQe7iKjbr+l7wXFM7UhybZ30CYZ10jgdLyP/PMVqVpgWSBm/I84FT+krUS
|
||||
pvx+9KEwzdwoHdZltTwFHr29RISsk4161R0+1pJmXBpa4EsKhlHvrXEpHDssG68h
|
||||
nDeqdFN20EJhf6L0Gab6UYGJqkaMecrdYrij+6Xu5jx3awn7mIsxCkj0jXtmNPZJ
|
||||
LQIDAQAB
|
||||
-----END PUBLIC KEY-----
|
||||
121
tests/test_license.py
Normal file
121
tests/test_license.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
import jwt as pyjwt
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_keys(tmp_path):
|
||||
"""Generate test RSA keypair and return (private_pem, public_pem, public_path)."""
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
public_pem = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
public_path = tmp_path / "test_public.pem"
|
||||
public_path.write_bytes(public_pem)
|
||||
return private_pem, public_pem, public_path
|
||||
|
||||
|
||||
def _make_jwt(private_pem: bytes, tier: str = "paid",
|
||||
product: str = "peregrine",
|
||||
exp_delta_days: int = 30,
|
||||
machine: str = "test-machine") -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"sub": "CFG-PRNG-TEST-TEST-TEST",
|
||||
"product": product,
|
||||
"tier": tier,
|
||||
"seats": 1,
|
||||
"machine": machine,
|
||||
"iat": now,
|
||||
"exp": now + timedelta(days=exp_delta_days),
|
||||
}
|
||||
return pyjwt.encode(payload, private_pem, algorithm="RS256")
|
||||
|
||||
|
||||
def _write_license(tmp_path, jwt_token: str, grace_until: str | None = None) -> Path:
|
||||
data = {
|
||||
"jwt": jwt_token,
|
||||
"key_display": "CFG-PRNG-TEST-TEST-TEST",
|
||||
"tier": "paid",
|
||||
"valid_until": None,
|
||||
"machine_id": "test-machine",
|
||||
"last_refresh": datetime.now(timezone.utc).isoformat(),
|
||||
"grace_until": grace_until,
|
||||
}
|
||||
p = tmp_path / "license.json"
|
||||
p.write_text(json.dumps(data))
|
||||
return p
|
||||
|
||||
|
||||
class TestVerifyLocal:
|
||||
def test_valid_jwt_returns_tier(self, test_keys, tmp_path):
|
||||
private_pem, _, public_path = test_keys
|
||||
token = _make_jwt(private_pem)
|
||||
license_path = _write_license(tmp_path, token)
|
||||
from scripts.license import verify_local
|
||||
result = verify_local(license_path=license_path, public_key_path=public_path)
|
||||
assert result is not None
|
||||
assert result["tier"] == "paid"
|
||||
|
||||
def test_missing_file_returns_none(self, tmp_path):
|
||||
from scripts.license import verify_local
|
||||
result = verify_local(license_path=tmp_path / "missing.json",
|
||||
public_key_path=tmp_path / "key.pem")
|
||||
assert result is None
|
||||
|
||||
def test_wrong_product_returns_none(self, test_keys, tmp_path):
|
||||
private_pem, _, public_path = test_keys
|
||||
token = _make_jwt(private_pem, product="falcon")
|
||||
license_path = _write_license(tmp_path, token)
|
||||
from scripts.license import verify_local
|
||||
result = verify_local(license_path=license_path, public_key_path=public_path)
|
||||
assert result is None
|
||||
|
||||
def test_expired_within_grace_returns_tier(self, test_keys, tmp_path):
|
||||
private_pem, _, public_path = test_keys
|
||||
token = _make_jwt(private_pem, exp_delta_days=-1)
|
||||
grace_until = (datetime.now(timezone.utc) + timedelta(days=3)).isoformat()
|
||||
license_path = _write_license(tmp_path, token, grace_until=grace_until)
|
||||
from scripts.license import verify_local
|
||||
result = verify_local(license_path=license_path, public_key_path=public_path)
|
||||
assert result is not None
|
||||
assert result["tier"] == "paid"
|
||||
assert result["in_grace"] is True
|
||||
|
||||
def test_expired_past_grace_returns_none(self, test_keys, tmp_path):
|
||||
private_pem, _, public_path = test_keys
|
||||
token = _make_jwt(private_pem, exp_delta_days=-10)
|
||||
grace_until = (datetime.now(timezone.utc) - timedelta(days=1)).isoformat()
|
||||
license_path = _write_license(tmp_path, token, grace_until=grace_until)
|
||||
from scripts.license import verify_local
|
||||
result = verify_local(license_path=license_path, public_key_path=public_path)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestEffectiveTier:
|
||||
def test_returns_free_when_no_license(self, tmp_path):
|
||||
from scripts.license import effective_tier
|
||||
result = effective_tier(
|
||||
license_path=tmp_path / "missing.json",
|
||||
public_key_path=tmp_path / "key.pem",
|
||||
)
|
||||
assert result == "free"
|
||||
|
||||
def test_returns_tier_from_valid_jwt(self, test_keys, tmp_path):
|
||||
private_pem, _, public_path = test_keys
|
||||
token = _make_jwt(private_pem, tier="premium")
|
||||
license_path = _write_license(tmp_path, token)
|
||||
from scripts.license import effective_tier
|
||||
result = effective_tier(license_path=license_path, public_key_path=public_path)
|
||||
assert result == "premium"
|
||||
69
tests/test_license_tier_integration.py
Normal file
69
tests/test_license_tier_integration.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
import jwt as pyjwt
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def license_env(tmp_path):
|
||||
"""Returns (private_pem, public_path, license_path) for tier integration tests."""
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
public_pem = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
public_path = tmp_path / "public.pem"
|
||||
public_path.write_bytes(public_pem)
|
||||
license_path = tmp_path / "license.json"
|
||||
return private_pem, public_path, license_path
|
||||
|
||||
|
||||
def _write_jwt_license(license_path, private_pem, tier="paid", days=30):
|
||||
now = datetime.now(timezone.utc)
|
||||
token = pyjwt.encode({
|
||||
"sub": "CFG-PRNG-TEST", "product": "peregrine", "tier": tier,
|
||||
"iat": now, "exp": now + timedelta(days=days),
|
||||
}, private_pem, algorithm="RS256")
|
||||
license_path.write_text(json.dumps({"jwt": token, "grace_until": None}))
|
||||
|
||||
|
||||
def test_effective_tier_free_without_license(tmp_path):
|
||||
from app.wizard.tiers import effective_tier
|
||||
tier = effective_tier(
|
||||
profile=None,
|
||||
license_path=tmp_path / "missing.json",
|
||||
public_key_path=tmp_path / "key.pem",
|
||||
)
|
||||
assert tier == "free"
|
||||
|
||||
|
||||
def test_effective_tier_paid_with_valid_license(license_env):
|
||||
private_pem, public_path, license_path = license_env
|
||||
_write_jwt_license(license_path, private_pem, tier="paid")
|
||||
from app.wizard.tiers import effective_tier
|
||||
tier = effective_tier(profile=None, license_path=license_path,
|
||||
public_key_path=public_path)
|
||||
assert tier == "paid"
|
||||
|
||||
|
||||
def test_effective_tier_dev_override_takes_precedence(license_env):
|
||||
"""dev_tier_override wins even when a valid license is present."""
|
||||
private_pem, public_path, license_path = license_env
|
||||
_write_jwt_license(license_path, private_pem, tier="paid")
|
||||
|
||||
class FakeProfile:
|
||||
dev_tier_override = "premium"
|
||||
|
||||
from app.wizard.tiers import effective_tier
|
||||
tier = effective_tier(profile=FakeProfile(), license_path=license_path,
|
||||
public_key_path=public_path)
|
||||
assert tier == "premium"
|
||||
Loading…
Reference in a new issue