turnstone/app/services/diagnose/_llm_client.py
pyr0ball 73a14bd782 fix(diagnose): add max_tokens to all LLM calls; fix reasoning card contrast
Truncation fix: call_llm() in _llm_client.py now accepts max_tokens (default
2048) and passes it in both the cf-orch task payload and the OpenAI-compat
fallback body. Hypothesizer uses max_tokens=1024 (JSON array output);
synthesizer and legacy summarize use 2048 (structured 5-section narrative).
Without this, backends use their own default (often 512 tokens), causing
mid-sentence truncation of the diagnosis output.

UI fix: reasoning card changed from bg-accent/5 border-accent/30 (opacity
modifiers on CSS variables don't compose reliably across themes) to the
callout pattern: bg-surface-raised with a solid border-l-4 border-accent.
Header label changed from text-text-dim to text-accent for visual anchoring.
Text remains text-text-primary for guaranteed contrast on both light and dark
themes.

Tracks: #56 (technical-level post-processor, filed as follow-on feature)
2026-05-27 22:23:36 -07:00

160 lines
5.5 KiB
Python

"""Shared LLM client for the multi-agent diagnose pipeline.
Both Stage 3 (RootCauseHypothesizer) and Stage 5 (SummarySynthesizer) send
messages to the same LLM backend using the same two-step pattern:
1. Try the cf-orch task endpoint → product-scoped inference routing.
2. Fall back to OpenAI-compat → direct model call by name.
Centralising here means changes to auth headers, timeouts, retry logic, or
cf-orch payload structure only need to be made once.
"""
from __future__ import annotations
import logging
import re
import httpx
logger = logging.getLogger(__name__)

# Matches markdown code-fence markers so they can be removed from LLM output:
#   - an opening fence at the start of a line, with an optional ``json`` tag
#     and any whitespace that follows it, or
#   - a closing fence at the end of a line, with any whitespace before it.
# MULTILINE makes ``^`` / ``$`` anchor at every line boundary, not only at the
# start and end of the whole string.
_JSON_FENCE_RE = re.compile(r"^```(?:json)?\s*|\s*```$", flags=re.MULTILINE)
def extract_content(resp_json: dict) -> str | None:
"""Pull text content from an OpenAI-compat chat completion response.
Returns None when the response has no choices or empty content.
"""
choices = resp_json.get("choices") or []
if not choices:
return None
return (choices[0].get("message", {}).get("content") or "").strip() or None
def strip_json_fences(raw: str) -> str:
    """Strip the markdown code fences that some LLMs wrap around JSON output.

    Every fence marker matched by ``_JSON_FENCE_RE`` is deleted, then leading
    and trailing whitespace is trimmed from what remains.

    Example: '```json\\n[...]\\n```' -> '[...]'
    """
    without_fences = _JSON_FENCE_RE.sub("", raw)
    return without_fences.strip()
def extract_first_json_array(raw: str) -> str:
    """Return the first balanced ``[...]`` span found in *raw*.

    Reasoning models (e.g. foundation-sec-8b) sometimes emit valid JSON and
    then repeat it inside a markdown fence, which makes ``json.loads()`` fail
    on the combined text. Scanning begins at the first ``[`` and stops at the
    ``]`` that closes it. Brackets inside double-quoted strings are ignored,
    and a backslash inside a string shields the character after it.

    Returns:
        The substring from the first ``[`` through its matching ``]``. When
        *raw* has no ``[`` or the brackets never balance, *raw* is returned
        unchanged so the caller's ``json.loads()`` reports the usual error.
    """
    opening = raw.find("[")
    if opening < 0:
        return raw

    nesting = 0
    inside_string = False
    skip_char = False  # True right after a backslash inside a string
    for pos in range(opening, len(raw)):
        char = raw[pos]
        if skip_char:
            # Escaped character: consume it without interpreting it.
            skip_char = False
        elif inside_string:
            if char == "\\":
                skip_char = True
            elif char == '"':
                inside_string = False
        elif char == '"':
            inside_string = True
        elif char == "[":
            nesting += 1
        elif char == "]":
            nesting -= 1
            if nesting == 0:
                return raw[opening : pos + 1]

    # Unbalanced: hand the text back untouched so the caller sees the error.
    return raw
def call_llm(
    llm_url: str,
    llm_model: str,
    llm_api_key: str | None,
    messages: list[dict],
    task_name: str = "log_analysis",
    timeout: float = 120.0,
    max_tokens: int = 2048,
) -> str | None:
    """Send *messages* to the LLM backend and return the raw reply text.

    Two request paths are attempted in order:

    1. The cf-orch task endpoint (``/api/inference/task``), which routes
       inference by product and task. A 200 response is final: its extracted
       content is returned even when that content is None.
    2. A direct OpenAI-compat ``/v1/chat/completions`` call. This runs when
       the task endpoint answers 404 (no assignment for this task) or when
       the first attempt raises anything at all (connection error, timeout,
       an error status surfaced by ``raise_for_status()``, a bad body, ...).

    Args:
        llm_url: Base URL of the LLM backend (e.g. ``http://<YOUR_HOST_IP>:7700``).
            A trailing slash is tolerated.
        llm_model: Model identifier sent in the OpenAI-compat fallback call.
        llm_api_key: Optional bearer token; when set it is sent as an
            ``Authorization`` header on both paths.
        messages: OpenAI-style message list (system + user turns).
        task_name: cf-orch task name for product-routed inference
            (default: ``log_analysis``).
        timeout: Per-request timeout in seconds (default: 120). It applies to
            each path separately.
        max_tokens: Maximum tokens to generate (default: 2048). Sent on both
            paths so a lower backend default does not truncate the output.

    Returns:
        The reply text, or None when no usable content was obtained.
    """
    base_url = llm_url.rstrip("/")
    auth_headers: dict[str, str] = (
        {"Authorization": f"Bearer {llm_api_key}"} if llm_api_key else {}
    )

    # --- Path 1: cf-orch task endpoint ---
    task_url = f"{base_url}/api/inference/task"
    task_body = {
        "product": "turnstone",
        "task": task_name,
        "payload": {"messages": messages, "stream": False, "max_tokens": max_tokens},
    }
    try:
        task_resp = httpx.post(
            task_url,
            json=task_body,
            headers=auth_headers,
            timeout=timeout,
        )
        status = task_resp.status_code
        if status == 200:
            return extract_content(task_resp.json())
        if status != 404:
            # Surface error statuses as exceptions so the handler below
            # logs them and execution continues with the fallback.
            task_resp.raise_for_status()
        logger.debug(
            "No task assignment for turnstone.%s — falling back to direct model",
            task_name,
        )
    except Exception as exc:  # noqa: BLE001
        # Deliberately broad: whatever went wrong on the task path, the
        # pipeline should still get a chance at the direct model call.
        logger.debug(
            "Task endpoint unavailable (%s) — falling back to direct model", exc
        )

    # --- Path 2: OpenAI-compat fallback ---
    chat_body = {
        "model": llm_model,
        "messages": messages,
        "stream": False,
        "max_tokens": max_tokens,
    }
    try:
        chat_resp = httpx.post(
            f"{base_url}/v1/chat/completions",
            json=chat_body,
            headers=auth_headers,
            timeout=timeout,
        )
        chat_resp.raise_for_status()
        return extract_content(chat_resp.json())
    except Exception as exc:  # noqa: BLE001
        logger.warning("LLM call failed (%s): %s", type(exc).__name__, exc)
        return None