Truncation fix: call_llm() in _llm_client.py now accepts max_tokens (default 2048) and passes it in both the cf-orch task payload and the OpenAI-compat fallback body. Hypothesizer uses max_tokens=1024 (JSON array output); synthesizer and legacy summarize use 2048 (structured 5-section narrative). Without this, backends use their own default (often 512 tokens), causing mid-sentence truncation of the diagnosis output. UI fix: reasoning card changed from bg-accent/5 border-accent/30 (opacity modifiers on CSS variables don't compose reliably across themes) to the callout pattern: bg-surface-raised with a solid border-l-4 border-accent. Header label changed from text-text-dim to text-accent for visual anchoring. Text remains text-text-primary for guaranteed contrast on both light and dark themes. Tracks: #56 (technical-level post-processor, filed as follow-on feature)
160 lines
5.5 KiB
Python
160 lines
5.5 KiB
Python
"""Shared LLM client for the multi-agent diagnose pipeline.
|
|
|
|
Both Stage 3 (RootCauseHypothesizer) and Stage 5 (SummarySynthesizer) send
|
|
messages to the same LLM backend using the same two-step pattern:
|
|
1. Try the cf-orch task endpoint → product-scoped inference routing.
|
|
2. Fall back to OpenAI-compat → direct model call by name.
|
|
|
|
Centralising here means changes to auth headers, timeouts, retry logic, or
|
|
cf-orch payload structure only need to be made once.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
|
|
import httpx
|
|
|
|
logger = logging.getLogger(__name__)

# Matches markdown code-fence markers so they can be deleted from LLM output:
# an opening ``` (optionally tagged "json") at the start of any line together
# with the whitespace after it, or a closing ``` at the end of any line
# together with the whitespace before it.
_JSON_FENCE_RE = re.compile(r"^```(?:json)?\s*|\s*```$", flags=re.MULTILINE)
|
|
|
|
|
|
def extract_content(resp_json: dict) -> str | None:
    """Pull text content from an OpenAI-compat chat completion response.

    Only the first entry of ``choices`` is inspected. Malformed payloads are
    reported the same way as empty ones (None) instead of raising, so callers
    have a single "no usable text" value to handle.

    Args:
        resp_json: Decoded JSON body of a chat completion response.

    Returns:
        The whitespace-stripped ``choices[0].message.content`` string, or None
        when the body is not an object, there are no choices, the first choice
        or its ``message`` is null or not an object, or ``content`` is
        missing, not a string, or blank.
    """
    if not isinstance(resp_json, dict):
        return None

    choices = resp_json.get("choices")
    if not isinstance(choices, (list, tuple)) or not choices:
        return None

    first_choice = choices[0]
    if not isinstance(first_choice, dict):
        return None

    # An explicit `"message": null` makes `.get("message", {})` return None,
    # so the default alone is not enough; check the type instead.
    message = first_choice.get("message")
    if not isinstance(message, dict):
        return None

    # `content` may be null or a non-string structure; only plain text is
    # usable here.
    content = message.get("content")
    if not isinstance(content, str):
        return None

    return content.strip() or None
|
|
|
|
|
|
def strip_json_fences(raw: str) -> str:
    """Delete markdown code-fence markers from LLM output.

    Every match of the module-level fence pattern is removed, then leading
    and trailing whitespace is trimmed from what remains.

    Example: '```json\\n[...]\\n```' → '[...]'

    Args:
        raw: Text returned by the model, possibly wrapped in a code fence.

    Returns:
        The text without fence markers or surrounding whitespace.
    """
    unfenced = _JSON_FENCE_RE.sub("", raw)
    return unfenced.strip()
|
|
|
|
|
|
def extract_first_json_array(raw: str) -> str:
    """Return the first bracket-balanced ``[...]`` span found in ``raw``.

    Some models emit a JSON array and then repeat it (for example inside a
    markdown fence), which makes json.loads() fail on the combined text. This
    scans from the first '[' to the ']' that closes it. Nested arrays are
    counted, and brackets inside double-quoted strings (including strings that
    contain backslash-escaped characters) are ignored.

    Args:
        raw: Text that may contain a JSON array surrounded by other output.

    Returns:
        The substring from the first '[' through its matching ']'. When there
        is no '[' or the brackets never balance, ``raw`` is returned unchanged
        so the caller's json.loads() raises its usual error.
    """
    opening = raw.find("[")
    if opening < 0:
        return raw

    nesting = 0
    inside_string = False
    skip_char = False

    for offset, char in enumerate(raw[opening:], opening):
        if skip_char:
            # Character following a backslash inside a string: consume it.
            skip_char = False
        elif inside_string:
            if char == "\\":
                skip_char = True
            elif char == '"':
                inside_string = False
        elif char == '"':
            inside_string = True
        elif char == "[":
            nesting += 1
        elif char == "]":
            nesting -= 1
            if nesting == 0:
                return raw[opening : offset + 1]

    # Ran out of input before the brackets balanced.
    return raw
|
|
|
|
|
|
def call_llm(
    llm_url: str,
    llm_model: str,
    llm_api_key: str | None,
    messages: list[dict],
    task_name: str = "log_analysis",
    timeout: float = 120.0,
    max_tokens: int = 2048,
) -> str | None:
    """Send ``messages`` to the LLM backend and return the reply text.

    Two endpoints under ``llm_url`` are tried in order:

    1. ``/api/inference/task`` — the cf-orch task endpoint, posted with
       product ``turnstone`` and the given ``task_name``. A 200 response is
       final: its extracted content is returned, even if that is None.
    2. ``/v1/chat/completions`` — the OpenAI-compat endpoint, addressed by
       ``llm_model``. Used after any non-200 outcome of step 1: a 404 (no
       assignment for the task), any other status, or any exception raised
       while calling the task endpoint or reading its response.

    Args:
        llm_url: Base URL of the LLM backend (e.g. ``http://<YOUR_HOST_IP>:7700``);
            trailing slashes are ignored.
        llm_model: Model identifier sent in the OpenAI-compat fallback body.
        llm_api_key: Optional bearer token; when truthy it is sent as an
            ``Authorization`` header on both requests.
        messages: OpenAI-style message list (system + user turns).
        task_name: cf-orch task name (default: ``log_analysis``).
        timeout: Per-request timeout in seconds (default: 120).
        max_tokens: Maximum tokens to generate (default: 2048), sent on both
            paths so the backend's own default does not cut the output short.

    Returns:
        Raw text content string, or None if no content could be obtained.
    """
    base_url = llm_url.rstrip("/")
    auth_headers: dict[str, str] = (
        {"Authorization": f"Bearer {llm_api_key}"} if llm_api_key else {}
    )

    # --- Step 1: product-routed inference through the cf-orch task endpoint ---
    task_body = {
        "product": "turnstone",
        "task": task_name,
        "payload": {"messages": messages, "stream": False, "max_tokens": max_tokens},
    }
    try:
        task_resp = httpx.post(
            f"{base_url}/api/inference/task",
            json=task_body,
            headers=auth_headers,
            timeout=timeout,
        )
        status = task_resp.status_code
        if status == 200:
            return extract_content(task_resp.json())
        if status != 404:
            # Turn unexpected statuses into an exception so the handler below
            # logs them before the fallback runs.
            task_resp.raise_for_status()
        logger.debug(
            "No task assignment for turnstone.%s — falling back to direct model",
            task_name,
        )
    except Exception as exc:  # noqa: BLE001
        # Deliberately broad: whatever goes wrong on the task path (network
        # error, timeout, bad status, undecodable body) must not stop the
        # fallback from being attempted.
        logger.debug(
            "Task endpoint unavailable (%s) — falling back to direct model", exc
        )

    # --- Step 2: direct model call through the OpenAI-compat endpoint ---
    chat_body = {
        "model": llm_model,
        "messages": messages,
        "stream": False,
        "max_tokens": max_tokens,
    }
    try:
        chat_resp = httpx.post(
            f"{base_url}/v1/chat/completions",
            json=chat_body,
            headers=auth_headers,
            timeout=timeout,
        )
        chat_resp.raise_for_status()
        return extract_content(chat_resp.json())
    except Exception as exc:  # noqa: BLE001
        logger.warning("LLM call failed (%s): %s", type(exc).__name__, exc)

    return None
|