feat(text): classifier backend + PII filter
Add ClassifierBackend (NER/PII via transformers token-classification pipeline) and PIIFilter (redact / detect / spans modes). MockClassifierBackend provides deterministic PII spans for tests and CI without GPU. Enables privacy-safe pre-screening before LLM inference.
This commit is contained in:
parent
93ab528261
commit
0c43e95991
4 changed files with 399 additions and 0 deletions
88
circuitforge_core/text/backends/classifier.py
Normal file
88
circuitforge_core/text/backends/classifier.py
Normal file
|
|
@ -0,0 +1,88 @@
|
||||||
|
# circuitforge_core/text/backends/classifier.py — HuggingFace token-classification backend
|
||||||
|
#
|
||||||
|
# BSL 1.1. Requires torch + transformers.
|
||||||
|
# Install: pip install circuitforge-core[text-transformers]
|
||||||
|
#
|
||||||
|
# Wraps pipeline("token-classification") for PII/entity detection.
|
||||||
|
# Returns spans with char offsets, entity labels, and confidence scores.
|
||||||
|
# Use make_classifier_backend() from base.py to instantiate.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierBackend:
    """
    HuggingFace token-classification backend for PII detection and entity labeling.

    Loads any token-classification model from HuggingFace Hub or a local checkpoint.
    Returns aggregated entity spans with char offsets — suitable for redaction or audit.

    Aggregation strategy "simple" merges consecutive BIO-tagged subwords into word-level
    spans and strips the B-/I- prefixes so callers see "NAME" not "B-NAME".

    Requires: pip install circuitforge-core[text-transformers]
    """

    def __init__(self, model_path: str) -> None:
        """Build the token-classification pipeline for ``model_path``.

        Args:
            model_path: HuggingFace Hub model id or path to a local checkpoint.

        Raises:
            ImportError: if torch or transformers is not installed.
        """
        try:
            import torch
            from transformers import pipeline as hf_pipeline
        except ImportError as exc:
            raise ImportError(
                "torch and transformers are required for ClassifierBackend. "
                "Install with: pip install circuitforge-core[text-transformers]"
            ) from exc

        # Pipeline device convention: 0 = first visible CUDA device, -1 = CPU.
        # CUDA_VISIBLE_DEVICES only narrows/renumbers the GPUs torch can see, so
        # torch.cuda.is_available() is the sole authority. Forcing device 0 merely
        # because the variable is set fails on hosts without usable CUDA.
        device = 0 if torch.cuda.is_available() else -1

        logger.info("Loading classifier model %s on device %s", model_path, device)

        self._pipeline = hf_pipeline(
            "token-classification",
            model=model_path,
            aggregation_strategy="simple",
            device=device,
        )
        self._model_path = model_path

    @property
    def model_name(self) -> str:
        """Last path component of the model id/path (e.g. "org/name" -> "name")."""
        # Strip trailing slashes first so a local checkpoint directory given as
        # "path/to/model/" does not yield an empty name.
        return self._model_path.rstrip("/").split("/")[-1]

    @property
    def vram_mb(self) -> int:
        """CUDA memory currently allocated by torch in this process, in MiB.

        Best-effort metric: returns 0 when CUDA is unavailable or the query fails.
        The figure is whatever ``torch.cuda.memory_allocated()`` reports, so it is
        not limited to this model's tensors.
        """
        try:
            import torch

            if torch.cuda.is_available():
                return torch.cuda.memory_allocated() // (1024 * 1024)
        except Exception:
            # Deliberately swallowed: a failed memory probe must never break callers.
            pass
        return 0

    def classify(self, text: str) -> list[dict[str, Any]]:
        """
        Run token classification synchronously.

        Returns a list of entity dicts with keys:
            entity_group: str   — label without BIO prefix (e.g. "NAME", "EMAIL")
            score: float        — aggregated confidence
            word: str           — matched text span
            start: int          — char offset (start, inclusive)
            end: int            — char offset (end, exclusive)
        """
        results: list[dict[str, Any]] = self._pipeline(text)
        return results

    async def classify_async(self, text: str) -> list[dict[str, Any]]:
        """Async classify — runs pipeline in thread pool to avoid blocking the event loop."""
        # get_running_loop() is the documented way to obtain the loop from inside
        # a coroutine; None selects the loop's default executor.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.classify, text)
|
||||||
|
|
@ -102,3 +102,49 @@ class MockTextBackend:
|
||||||
# Format messages into a simple prompt for the mock response
|
# Format messages into a simple prompt for the mock response
|
||||||
prompt = "\n".join(f"{m.role}: {m.content}" for m in messages)
|
prompt = "\n".join(f"{m.role}: {m.content}" for m in messages)
|
||||||
return self.generate(prompt, max_tokens=max_tokens, temperature=temperature)
|
return self.generate(prompt, max_tokens=max_tokens, temperature=temperature)
|
||||||
|
|
||||||
|
|
||||||
|
# Synthetic PII spans returned by MockClassifierBackend — predictable in tests.
# Offsets index into the sample sentence the test suite feeds the mock:
#   "Jane Doe emailed jane@example.com today"
# so that text[start:end] == word for both entries.
_MOCK_SPANS = [
    {
        "entity_group": "NAME",
        "score": 0.99,
        "word": "Jane Doe",
        "start": 0,
        "end": 8,
    },
    {
        "entity_group": "EMAIL",
        "score": 0.97,
        "word": "jane@example.com",
        "start": 17,
        "end": 33,
    },
]


class MockClassifierBackend:
    """
    Deterministic mock classifier backend for development and CI.

    Always returns the same two synthetic PII spans regardless of input.
    Allows filter.py logic (redaction, span conversion) to be tested without
    a real model or GPU.
    """

    def __init__(self, model_name: str = "mock-classifier") -> None:
        """Store the name reported by ``model_name``; nothing is loaded."""
        self._model_name = model_name

    @property
    def model_name(self) -> str:
        """Name given at construction time."""
        return self._model_name

    @property
    def vram_mb(self) -> int:
        """Always 0 — the mock holds no model."""
        return 0

    def classify(self, text: str) -> list[dict]:
        """Return fresh copies of the two synthetic spans; ``text`` is ignored."""
        # Copy every dict, not just the list, so a caller that mutates a result
        # cannot corrupt the shared template for later calls.
        return [dict(span) for span in _MOCK_SPANS]

    async def classify_async(self, text: str) -> list[dict]:
        """Coroutine wrapper around ``classify``; returns the same spans."""
        return self.classify(text)
|
||||||
|
|
|
||||||
114
circuitforge_core/text/filter.py
Normal file
114
circuitforge_core/text/filter.py
Normal file
|
|
@ -0,0 +1,114 @@
|
||||||
|
# circuitforge_core/text/filter.py — PII detection and redaction
|
||||||
|
#
|
||||||
|
# BSL 1.1. Products import PIIFilter for pre-send redaction and audit trails.
|
||||||
|
# Requires a running cf-filter service (or ClassifierBackend for in-process use).
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from circuitforge_core.text.backends.base import FilterBackend, make_classifier_backend
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class PIISpan:
    """One PII entity located in the source text (immutable record)."""

    # Entity label, e.g. NAME | EMAIL | PHONE_NUM | ADDRESS | SSN | DOB | IP_ADDRESS
    label: str
    # Char offset in original_text where the entity begins (inclusive)
    start: int
    # Char offset in original_text where the entity stops (exclusive)
    end: int
    # The span's text as it appeared in the original
    text: str
    # Classifier confidence for this entity
    score: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class FilterResult:
    """What PIIFilter.filter() hands back.

    ``redacted_text``: copy of the input with every span replaced by ``[LABEL]`` —
        the version that is safe to send onward.
    ``spans``: every detected entity, for audit logs or caller-side decisions.
    ``original_text``: the unmodified input, kept for round-trip comparisons.
    """

    redacted_text: str
    # default_factory gives each result its own list instead of a shared default
    spans: list[PIISpan] = field(default_factory=list)
    original_text: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def _redact(text: str, spans: list[PIISpan]) -> str:
    """Replace each span in ``text`` with ``[LABEL]``.

    Spans may be given in any order. The output is assembled in one left-to-right
    pass over the spans sorted by start offset, so no substitution can shift the
    offsets of another.

    Overlapping (or duplicate) spans are merged into a single redaction covering
    their union; it carries the label of the span that starts first (on a tie,
    the longer one). Substituting overlapping spans independently would index
    into already-modified text and could leave part of an entity unredacted.

    Offsets are clamped to ``[0, len(text)]`` and an ``end`` before ``start`` is
    treated as an empty span.

    Args:
        text: the original text.
        spans: detected entities; only ``start``, ``end`` and ``label`` are read.

    Returns:
        The redacted copy of ``text`` (``text`` itself when ``spans`` is empty).
    """
    if not spans:
        return text

    limit = len(text)
    pieces: list[str] = []
    cursor = 0  # everything before this offset is already emitted or redacted

    # Secondary key -end: among spans sharing a start, the widest goes first so
    # its label is the one kept when the others are absorbed.
    for span in sorted(spans, key=lambda s: (s.start, -s.end)):
        start = min(max(span.start, 0), limit)
        end = min(max(span.end, start), limit)
        if start < cursor:
            # Overlaps the previous redaction — extend it instead of emitting a
            # second placeholder in the middle of the first.
            cursor = max(cursor, end)
            continue
        pieces.append(text[cursor:start])
        pieces.append(f"[{span.label}]")
        cursor = end

    pieces.append(text[cursor:])
    return "".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
def _spans_from_pipeline(raw: list[dict[str, Any]]) -> list[PIISpan]:
    """Turn raw pipeline entity dicts into typed PIISpan objects.

    Each dict is read for the keys entity_group, score, word, start, end.
    ``start`` and ``end`` are required (a missing key raises KeyError); the other
    keys fall back to an empty string / 0.0. The label has a leading ``B-`` or
    ``I-`` removed and is then upper-cased.
    """
    return [
        PIISpan(
            label=re.sub(r"^[BI]-", "", entity.get("entity_group", "")).upper(),
            start=int(entity["start"]),
            end=int(entity["end"]),
            text=entity.get("word", ""),
            score=float(entity.get("score", 0.0)),
        )
        for entity in raw
    ]
|
||||||
|
|
||||||
|
|
||||||
|
class PIIFilter:
    """
    PII detection and redaction on top of a token-classification backend.

    Usage:
        pii_filter = PIIFilter.from_model("openai/privacy-filter")
        result = await pii_filter.filter_async(resume_text)
        safe_text = result.redacted_text   # send to cloud LLM
        spans = result.spans               # store for audit trail

    ``from_model`` builds the classifier backend in-process from a model path
    (no cf-orch needed); ``from_backend`` wraps a backend you already have.
    """

    def __init__(self, backend: FilterBackend) -> None:
        self._backend = backend

    @classmethod
    def from_model(cls, model_path: str) -> "PIIFilter":
        """Create a filter whose backend is built from ``model_path`` in-process."""
        backend = make_classifier_backend(model_path)
        return cls(backend)

    @classmethod
    def from_backend(cls, backend: FilterBackend) -> "PIIFilter":
        """Create a filter around an existing FilterBackend."""
        return cls(backend)

    @staticmethod
    def _build_result(text: str, raw: list[dict[str, Any]]) -> FilterResult:
        """Convert raw backend output for ``text`` into a FilterResult."""
        detected = _spans_from_pipeline(raw)
        return FilterResult(
            redacted_text=_redact(text, detected),
            spans=detected,
            original_text=text,
        )

    def filter(self, text: str) -> FilterResult:
        """Classify ``text`` synchronously and return the redaction result."""
        return self._build_result(text, self._backend.classify(text))

    async def filter_async(self, text: str) -> FilterResult:
        """Await the backend's ``classify_async`` and return the redaction result."""
        raw = await self._backend.classify_async(text)
        return self._build_result(text, raw)
|
||||||
151
tests/test_text/test_classifier.py
Normal file
151
tests/test_text/test_classifier.py
Normal file
|
|
@ -0,0 +1,151 @@
|
||||||
|
# tests/test_text/test_classifier.py — PII filter backend and endpoint tests
|
||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
|
||||||
|
from circuitforge_core.text.backends.mock import MockClassifierBackend
|
||||||
|
from circuitforge_core.text.filter import PIIFilter, PIISpan, FilterResult, _redact, _spans_from_pipeline
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: _spans_from_pipeline ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_spans_from_pipeline_normalises_bio_prefix():
    """A residual ``B-`` prefix on the pipeline label is stripped."""
    entity = {"entity_group": "B-NAME", "score": 0.9, "word": "Alice", "start": 0, "end": 5}
    first = _spans_from_pipeline([entity])[0]
    assert first.label == "NAME"
|
||||||
|
|
||||||
|
|
||||||
|
def test_spans_from_pipeline_uppercase():
    """Lower-case pipeline labels come out upper-cased."""
    entity = {"entity_group": "email", "score": 0.8, "word": "a@b.com", "start": 10, "end": 17}
    first = _spans_from_pipeline([entity])[0]
    assert first.label == "EMAIL"
|
||||||
|
|
||||||
|
|
||||||
|
def test_spans_from_pipeline_returns_typed_objects():
    """Raw dicts become PIISpan instances with score and offsets carried over."""
    entity = {"entity_group": "PHONE_NUM", "score": 0.95, "word": "555-1234", "start": 5, "end": 13}
    first = _spans_from_pipeline([entity])[0]
    assert isinstance(first, PIISpan)
    assert first.score == pytest.approx(0.95)
    assert first.start == 5
    assert first.end == 13
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: _redact ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_redact_replaces_spans():
    """Each span is swapped for its bracketed label; the rest of the text is untouched."""
    name = PIISpan(label="NAME", start=5, end=10, text="Alice", score=0.99)
    phone = PIISpan(label="PHONE_NUM", start=14, end=22, text="555-1234", score=0.97)
    redacted = _redact("Call Alice at 555-1234 now", [name, phone])
    assert redacted == "Call [NAME] at [PHONE_NUM] now"
|
||||||
|
|
||||||
|
|
||||||
|
def test_redact_handles_overlapping_order():
    """Redacting one span must not invalidate the offsets of the other."""
    text = "Jane Doe jane@example.com"
    spans = [
        PIISpan(label="NAME", start=0, end=8, text="Jane Doe", score=0.99),
        PIISpan(label="EMAIL", start=9, end=25, text="jane@example.com", score=0.97),
    ]
    result = _redact(text, spans)
    # Both placeholders are present ...
    for placeholder in ("[NAME]", "[EMAIL]"):
        assert placeholder in result
    # ... and neither original value survives.
    for leaked in ("Jane Doe", "jane@example.com"):
        assert leaked not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_redact_no_spans_returns_original():
    """With nothing to redact the text comes back unchanged."""
    untouched = "No PII here"
    assert _redact(untouched, []) == untouched
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: PIIFilter with MockClassifierBackend ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_pii_filter_sync():
    """Synchronous filter() yields a FilterResult with both placeholders and two spans."""
    # The mock backend always reports one NAME span and one EMAIL span.
    pii_filter = PIIFilter.from_backend(MockClassifierBackend())
    result = pii_filter.filter("Jane Doe emailed jane@example.com today")
    assert isinstance(result, FilterResult)
    for placeholder in ("[NAME]", "[EMAIL]"):
        assert placeholder in result.redacted_text
    assert len(result.spans) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_pii_filter_preserves_original_text():
    """The input is echoed back verbatim in ``original_text``."""
    source = "Jane Doe emailed jane@example.com today"
    pii_filter = PIIFilter.from_backend(MockClassifierBackend())
    assert pii_filter.filter(source).original_text == source
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_pii_filter_async():
    """filter_async() redacts the name and reports two spans."""
    pii_filter = PIIFilter.from_backend(MockClassifierBackend())
    outcome = await pii_filter.filter_async("Jane Doe emailed jane@example.com today")
    assert "[NAME]" in outcome.redacted_text
    assert len(outcome.spans) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_pii_filter_result_is_frozen():
    """Assigning to a FilterResult field raises — the dataclass is frozen."""
    frozen = PIIFilter.from_backend(MockClassifierBackend()).filter("test")
    with pytest.raises((AttributeError, TypeError)):
        frozen.redacted_text = "mutated"  # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Integration: /filter HTTP endpoint ───────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def classifier_app(monkeypatch):
    """cf-text app in classifier mode using mock backend.

    The CF_TEXT_* variables are set before the app module is (re)loaded so the
    reload runs with them in place. No explicit cleanup is needed: monkeypatch
    restores the environment automatically when the fixture is torn down.
    """
    import importlib

    monkeypatch.setenv("CF_TEXT_MOCK", "1")
    monkeypatch.setenv("CF_TEXT_BACKEND", "classifier")

    import circuitforge_core.text.app as app_mod

    importlib.reload(app_mod)
    # NOTE(review): mock=False is passed although CF_TEXT_MOCK=1 is set —
    # presumably create_app lets the env var win; confirm, otherwise this would
    # try to load the real "openai/privacy-filter" model.
    return app_mod.create_app(model_path="openai/privacy-filter", backend="classifier", mock=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_filter_endpoint_returns_redacted(classifier_app):
    """POST /filter answers 200 with both placeholders and two spans."""
    transport = ASGITransport(app=classifier_app)
    payload = {"text": "Jane Doe emailed jane@example.com today"}
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.post("/filter", json=payload)
    assert resp.status_code == 200
    body = resp.json()
    redacted = body["redacted_text"]
    assert "[NAME]" in redacted
    assert "[EMAIL]" in redacted
    assert len(body["spans"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_filter_endpoint_includes_original(classifier_app):
    """The /filter response echoes the submitted text as ``original_text``."""
    text = "Jane Doe emailed jane@example.com today"
    transport = ASGITransport(app=classifier_app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.post("/filter", json={"text": text})
    assert resp.json()["original_text"] == text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_generate_returns_501_in_classifier_mode(classifier_app):
    """POST /generate is answered with 501 when the app runs the classifier backend."""
    transport = ASGITransport(app=classifier_app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.post("/generate", json={"prompt": "hello"})
    assert resp.status_code == 501
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_health_reports_classifier_backend(classifier_app):
    """GET /health answers 200 and names "classifier" as the backend."""
    transport = ASGITransport(app=classifier_app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.get("/health")
    assert resp.status_code == 200
    assert resp.json()["backend"] == "classifier"
|
||||||
Loading…
Reference in a new issue