"""LiteLLM wrapper for multi-provider AI support.""" import json import logging import os import re from contextlib import asynccontextmanager from contextvars import ContextVar from dataclasses import dataclass from typing import Any import httpx import litellm from pydantic import BaseModel from app.config import settings # LLM timeout configuration (seconds) - base values LLM_TIMEOUT_HEALTH_CHECK = 30 LLM_TIMEOUT_COMPLETION = 120 LLM_TIMEOUT_JSON = 180 # JSON completions may take longer # LLM-004: OpenRouter JSON-capable models (explicit allowlist) OPENROUTER_JSON_CAPABLE_MODELS = { # Anthropic models "anthropic/claude-3-opus", "anthropic/claude-3-sonnet", "anthropic/claude-3-haiku", "anthropic/claude-3.5-sonnet", "anthropic/claude-3.5-haiku", "anthropic/claude-haiku-4-5-20251001", "anthropic/claude-sonnet-4-20250514", "anthropic/claude-opus-4-20250514", # OpenAI models "openai/gpt-4-turbo", "openai/gpt-4", "openai/gpt-4o", "openai/gpt-4o-mini", "openai/gpt-3.5-turbo", "openai/gpt-5-nano-2025-08-07", # Google models "google/gemini-pro", "google/gemini-1.5-pro", "google/gemini-1.5-flash", "google/gemini-2.0-flash", "google/gemini-3-flash-preview", # DeepSeek models "deepseek/deepseek-chat", "deepseek/deepseek-reasoner", # Mistral models "mistralai/mistral-large", "mistralai/mistral-medium", } # JSON-010: JSON extraction safety limits MAX_JSON_EXTRACTION_RECURSION = 10 MAX_JSON_CONTENT_SIZE = 1024 * 1024 # 1MB # Request-scoped user_id — set once by session_middleware_dep, read inside _allocate_orch_async. # ContextVar is safe for concurrent async requests: each request task gets its own copy. 
_request_user_id: ContextVar[str | None] = ContextVar("request_user_id", default=None)


def set_request_user_id(user_id: str | None) -> None:
    """Bind ``user_id`` to the current context for later cf-orch allocations.

    Args:
        user_id: Identifier forwarded to the cf-orch coordinator, or None to clear it.
    """
    _request_user_id.set(user_id)


def get_request_user_id() -> str | None:
    """Return the user_id bound to the current context, or None if unset."""
    return _request_user_id.get()


class LLMConfig(BaseModel):
    """LLM configuration model."""

    provider: str  # e.g. "openai", "anthropic", "openrouter", "gemini", "deepseek", "ollama"
    model: str  # model id, with or without a LiteLLM provider prefix (see get_model_name)
    api_key: str
    api_base: str | None = None  # optional custom endpoint / proxy base URL


@dataclass
class _OrchAllocation:
    """A service allocation handed out by the cf-orch coordinator."""

    allocation_id: str  # id used to release the allocation on exit
    url: str  # base URL of the allocated service
    service: str  # service name the allocation was requested for


@asynccontextmanager
async def _allocate_orch_async(
    coordinator_url: str,
    service: str,
    model_candidates: list[str],
    ttl_s: float,
    caller: str,
):
    """Async context manager that allocates a cf-orch service and releases on exit.

    Args:
        coordinator_url: Base URL of the cf-orch coordinator.
        service: Service name to allocate (e.g. "vllm").
        model_candidates: Acceptable models, sent to the coordinator as-is.
        ttl_s: Requested allocation lifetime in seconds.
        caller: Identifier of the requesting application.

    Yields:
        The ``_OrchAllocation`` describing the allocated service.

    Raises:
        RuntimeError: If the coordinator answers with a non-success status.
    """
    async with httpx.AsyncClient(timeout=120.0) as client:
        payload: dict[str, Any] = {
            "model_candidates": model_candidates,
            "ttl_s": ttl_s,
            "caller": caller,
        }
        # Attach the request-scoped user id only when one was set for this context.
        uid = get_request_user_id()
        if uid:
            payload["user_id"] = uid

        resp = await client.post(
            f"{coordinator_url.rstrip('/')}/api/services/{service}/allocate",
            json=payload,
        )
        if not resp.is_success:
            raise RuntimeError(
                f"cf-orch allocation failed for {service!r}: "
                f"HTTP {resp.status_code} — {resp.text[:200]}"
            )

        data = resp.json()
        alloc = _OrchAllocation(
            allocation_id=data["allocation_id"],
            url=data["url"],
            service=service,
        )
        try:
            yield alloc
        finally:
            # Best-effort release: a failed release must not mask the caller's
            # result or exception, so it is only logged.
            try:
                await client.delete(
                    f"{coordinator_url.rstrip('/')}/api/services/{service}/allocations/{alloc.allocation_id}",
                    timeout=10.0,
                )
            except Exception as exc:
                logging.debug("cf-orch release failed (non-fatal): %s", exc)


def _normalize_api_base(provider: str, api_base: str | None) -> str | None:
    """Normalize api_base for LiteLLM provider-specific expectations.

    When using proxies/aggregators, users often paste a base URL that already
    includes a version segment (e.g., `/v1`). Some LiteLLM provider handlers
    append those segments internally, which can lead to duplicated paths like
    `/v1/v1/...` and cause 404s.

    Args:
        provider: Provider name from the config.
        api_base: Raw base URL, possibly with whitespace or trailing slashes.

    Returns:
        The cleaned base URL, or None when nothing usable remains.
    """
    if not api_base:
        return None

    base = api_base.strip()
    if not base:
        return None

    base = base.rstrip("/")

    # The Anthropic handler appends '/v1/messages' and the Gemini handler
    # appends '/v1/models/...', so a trailing '/v1' here would be duplicated.
    if provider in ("anthropic", "gemini") and base.endswith("/v1"):
        base = base[: -len("/v1")].rstrip("/")

    return base or None


def _extract_text_parts(value: Any, depth: int = 0, max_depth: int = 10) -> list[str]:
    """Recursively extract text segments from nested response structures.

    Handles strings, lists, dicts with 'text'/'content'/'value' keys, and objects
    with text/content attributes. Limits recursion depth to avoid cycles.

    Args:
        value: Input value that may contain text in strings, lists, dicts, or objects.
        depth: Current recursion depth.
        max_depth: Maximum recursion depth before returning no content.

    Returns:
        A list of extracted text segments.
    """
    if depth >= max_depth:
        return []

    if value is None:
        return []

    if isinstance(value, str):
        return [value]

    if isinstance(value, list):
        parts: list[str] = []
        next_depth = depth + 1
        for item in value:
            parts.extend(_extract_text_parts(item, next_depth, max_depth))
        return parts

    if isinstance(value, dict):
        next_depth = depth + 1
        # First matching key wins; later keys are not consulted.
        if "text" in value:
            return _extract_text_parts(value.get("text"), next_depth, max_depth)
        if "content" in value:
            return _extract_text_parts(value.get("content"), next_depth, max_depth)
        if "value" in value:
            return _extract_text_parts(value.get("value"), next_depth, max_depth)
        return []

    next_depth = depth + 1
    if hasattr(value, "text"):
        return _extract_text_parts(getattr(value, "text"), next_depth, max_depth)
    if hasattr(value, "content"):
        return _extract_text_parts(getattr(value, "content"), next_depth, max_depth)

    return []


def _join_text_parts(parts: list[str]) -> str | None:
    """Join text parts with newlines, filtering empty strings.

    Args:
        parts: Candidate text segments.

    Returns:
        Joined string or None if the result is empty.
    """
    joined = "\n".join(part for part in parts if part).strip()
    return joined or None


def _extract_message_text(message: Any) -> str | None:
    """Extract plain text from a LiteLLM message object across providers.

    Args:
        message: Message object (with a ``content`` attribute) or dict.

    Returns:
        The joined text, or None when no text is present.
    """
    content: Any = None
    if hasattr(message, "content"):
        content = message.content
    elif isinstance(message, dict):
        content = message.get("content")

    return _join_text_parts(_extract_text_parts(content))


def _extract_choice_text(choice: Any) -> str | None:
    """Extract plain text from a LiteLLM choice object.

    Tries message.content first, then choice.text, then choice.delta.
    Handles both object attributes and dict keys.

    Args:
        choice: LiteLLM choice object or dict.

    Returns:
        Extracted text or None if no content is found.
    """
    message: Any = None
    if hasattr(choice, "message"):
        message = choice.message
    elif isinstance(choice, dict):
        message = choice.get("message")

    content = _extract_message_text(message)
    if content:
        return content

    if hasattr(choice, "text"):
        content = _join_text_parts(_extract_text_parts(getattr(choice, "text")))
        if content:
            return content

    if isinstance(choice, dict) and "text" in choice:
        content = _join_text_parts(_extract_text_parts(choice.get("text")))
        if content:
            return content

    if hasattr(choice, "delta"):
        content = _join_text_parts(_extract_text_parts(getattr(choice, "delta")))
        if content:
            return content

    if isinstance(choice, dict) and "delta" in choice:
        content = _join_text_parts(_extract_text_parts(choice.get("delta")))
        if content:
            return content

    return None


def _to_code_block(content: str | None, language: str = "text") -> str:
    """Wrap content in a markdown code block for client display.

    Args:
        content: Text to wrap; None is treated as empty.
        language: Fence language tag.

    Returns:
        A fenced markdown code block (with an empty body for empty content).
    """
    text = (content or "").strip()
    return f"```{language}\n{text}\n```"


def _load_stored_config() -> dict:
    """Load config from config.json file.

    Returns:
        The stored settings, or an empty dict when the file is missing,
        unreadable, not valid UTF-8 JSON, or not a JSON object.
    """
    config_path = settings.config_path
    if not config_path.exists():
        return {}

    try:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        stored = json.loads(config_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError, OSError):
        return {}

    # get_llm_config calls .get() on the result, so any non-object JSON
    # (list, string, number, null) is treated as "no stored config".
    return stored if isinstance(stored, dict) else {}


def get_llm_config() -> LLMConfig:
    """Get current LLM configuration.

    Priority: config.json file > environment variables/settings
    """
    stored = _load_stored_config()
    return LLMConfig(
        provider=stored.get("provider", settings.llm_provider),
        model=stored.get("model", settings.llm_model),
        api_key=stored.get("api_key", settings.llm_api_key),
        api_base=stored.get("api_base", settings.llm_api_base),
    )


def get_model_name(config: LLMConfig) -> str:
    """Convert provider/model to LiteLLM format.

    For most providers, adds the provider prefix if not already present.
    For OpenRouter, always adds 'openrouter/' prefix since OpenRouter models
    use nested prefixes like 'openrouter/anthropic/claude-3.5-sonnet'.
    """
    provider_prefixes = {
        "openai": "",  # OpenAI models don't need prefix
        "anthropic": "anthropic/",
        "openrouter": "openrouter/",
        "gemini": "gemini/",
        "deepseek": "deepseek/",
        "ollama": "ollama/",
    }

    prefix = provider_prefixes.get(config.provider, "")

    # OpenRouter is special: always add openrouter/ prefix unless already present
    # OpenRouter models use nested format: openrouter/anthropic/claude-3.5-sonnet
    if config.provider == "openrouter":
        if config.model.startswith("openrouter/"):
            return config.model
        return f"openrouter/{config.model}"

    # For other providers, don't add prefix if model already has a known prefix
    known_prefixes = ["openrouter/", "anthropic/", "gemini/", "deepseek/", "ollama/"]
    if any(config.model.startswith(p) for p in known_prefixes):
        return config.model

    # Add provider prefix for models that need it
    return f"{prefix}{config.model}" if prefix else config.model


def _supports_temperature(provider: str, model: str) -> bool:
    """Return whether passing `temperature` is supported for this model/provider combo.

    Some models (e.g., OpenAI gpt-5 family) reject temperature values other than 1,
    and LiteLLM may error when temperature is passed.
    """
    _ = provider  # currently decided by model name alone
    return "gpt-5" not in model.lower()


def _get_reasoning_effort(provider: str, model: str) -> str | None:
    """Return a default reasoning_effort for models that require it.

    Some OpenAI gpt-5 models may return empty message.content unless a supported
    `reasoning_effort` is explicitly set. This keeps downstream JSON parsing reliable.
    """
    _ = provider  # currently decided by model name alone
    if "gpt-5" in model.lower():
        return "minimal"
    return None


async def check_llm_health(
    config: LLMConfig | None = None,
    *,
    include_details: bool = False,
    test_prompt: str | None = None,
) -> dict[str, Any]:
    """Check if the LLM provider is accessible and working.

    Sends a short prompt (``test_prompt`` or "Hi") with ``max_tokens=16`` and
    reports whether non-empty text came back. Exceptions from the provider call
    are logged server-side and summarized as an ``error_code``, not raised.

    Args:
        config: Configuration to test; defaults to get_llm_config().
        include_details: When True, add markdown code blocks with the prompt
            and the model output to the result.
        test_prompt: Prompt to send instead of the default "Hi".

    Returns:
        A dict with "healthy", "provider" and "model"; failures also carry
        "error_code".
    """
    if config is None:
        config = get_llm_config()

    # Check if API key is configured (except for Ollama)
    if config.provider != "ollama" and not config.api_key:
        return {
            "healthy": False,
            "provider": config.provider,
            "model": config.model,
            "error_code": "api_key_missing",
        }

    model_name = get_model_name(config)
    prompt = test_prompt or "Hi"

    try:
        # Make a minimal test call with timeout
        # Pass API key directly to avoid race conditions with global os.environ
        kwargs: dict[str, Any] = {
            "model": model_name,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 16,
            "api_key": config.api_key,
            "api_base": _normalize_api_base(config.provider, config.api_base),
            "timeout": LLM_TIMEOUT_HEALTH_CHECK,
        }
        reasoning_effort = _get_reasoning_effort(config.provider, model_name)
        if reasoning_effort:
            kwargs["reasoning_effort"] = reasoning_effort

        response = await litellm.acompletion(**kwargs)
        content = _extract_choice_text(response.choices[0])

        if not content:
            # LLM-003: an empty response marks the health check as unhealthy.
            logging.warning(
                "LLM health check returned empty content",
                extra={"provider": config.provider, "model": config.model},
            )
            result: dict[str, Any] = {
                "healthy": False,
                "provider": config.provider,
                "model": config.model,
                "response_model": response.model if response else None,
                "error_code": "empty_content",
                "message": "LLM returned empty response",
            }
            if include_details:
                result["test_prompt"] = _to_code_block(prompt)
                result["model_output"] = _to_code_block(None)
            return result

        result = {
            "healthy": True,
            "provider": config.provider,
            "model": config.model,
            "response_model": response.model if response else None,
        }
        if include_details:
            result["test_prompt"] = _to_code_block(prompt)
            result["model_output"] = _to_code_block(content)
        return result
    except Exception as e:
        # Log full exception details server-side, but do not expose them to clients
        logging.exception(
            "LLM health check failed",
            extra={"provider": config.provider, "model": config.model},
        )

        # Provide a minimal, actionable client-facing hint without leaking secrets.
        error_code = "health_check_failed"
        message = str(e)
        if "404" in message and "/v1/v1/" in message:
            error_code = "duplicate_v1_path"
        elif "404" in message:
            error_code = "not_found_404"
        # NOTE(review): the remainder of this handler was lost from SOURCE (only
        # `elif "` survived, which suggests an HTML-error-page test whose `<...`
        # literal was stripped). The branch and the returned dict below are a
        # reconstruction matching the other failure results — confirm the exact
        # condition and error_code against version control.
        elif "<html" in message.lower() or "<!doctype" in message.lower():
            error_code = "html_response"

        result = {
            "healthy": False,
            "provider": config.provider,
            "model": config.model,
            "error_code": error_code,
        }
        if include_details:
            result["test_prompt"] = _to_code_block(prompt)
            result["model_output"] = _to_code_block(None)
        return result


# NOTE(review): this signature was lost from SOURCE together with the tail of
# check_llm_health. Parameter order is fixed by the recursive call below and
# max_tokens mirrors complete_json; the temperature default is a best guess —
# confirm against version control.
async def complete(
    prompt: str,
    system_prompt: str | None = None,
    config: LLMConfig | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
) -> str:
    """Make a completion request to the LLM.

    When no ``config`` is given and the ``CF_ORCH_URL`` environment variable is
    set, a vLLM backend is first allocated through cf-orch and used for the
    request; if anything in that path fails, the default config is used instead.

    Args:
        prompt: User message.
        system_prompt: Optional system message sent before the prompt.
        config: Explicit configuration; bypasses the cf-orch path when given.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (omitted for models that reject it).

    Returns:
        The extracted response text.

    Raises:
        ValueError: If the request fails or the response has no text.
    """
    if config is None:
        cf_orch_url = os.environ.get("CF_ORCH_URL", "").strip()
        if cf_orch_url:
            # Any exception in this block (allocation or the completion itself on
            # the allocated backend) triggers the fallback to the default config.
            try:
                async with _allocate_orch_async(
                    cf_orch_url,
                    "vllm",
                    model_candidates=["Qwen2.5-3B-Instruct"],
                    ttl_s=300.0,
                    caller="peregrine-resume-matcher",
                ) as alloc:
                    orch_config = LLMConfig(
                        provider="openai",
                        model="__auto__",
                        api_key="any",
                        api_base=alloc.url.rstrip("/") + "/v1",
                    )
                    return await complete(prompt, system_prompt, orch_config, max_tokens, temperature)
            except Exception as exc:
                logging.warning("cf-orch allocation failed, falling back to default config: %s", exc)
        config = get_llm_config()

    model_name = get_model_name(config)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    try:
        # Pass API key directly to avoid race conditions with global os.environ
        kwargs: dict[str, Any] = {
            "model": model_name,
            "messages": messages,
            "max_tokens": max_tokens,
            "api_key": config.api_key,
            "api_base": _normalize_api_base(config.provider, config.api_base),
            "timeout": LLM_TIMEOUT_COMPLETION,
        }
        if _supports_temperature(config.provider, model_name):
            kwargs["temperature"] = temperature
        reasoning_effort = _get_reasoning_effort(config.provider, model_name)
        if reasoning_effort:
            kwargs["reasoning_effort"] = reasoning_effort

        response = await litellm.acompletion(**kwargs)

        content = _extract_choice_text(response.choices[0])
        if not content:
            raise ValueError("Empty response from LLM")
        return content
    except Exception as e:
        # Log the actual error server-side for debugging
        logging.error("LLM completion failed: %s", e, extra={"model": model_name})
        raise ValueError(
            "LLM completion failed. Please check your API configuration and try again."
        ) from e


def _supports_json_mode(provider: str, model: str) -> bool:
    """Check if the model supports JSON mode.

    Args:
        provider: Provider name from the config.
        model: Model id from the config (OpenRouter ids may carry an
            'openrouter/' prefix, which get_model_name also accepts).

    Returns:
        True for providers that accept response_format={"type": "json_object"}
        for any model, or for allowlisted OpenRouter models.
    """
    # Models that support response_format={"type": "json_object"}
    json_mode_providers = ["openai", "anthropic", "gemini", "deepseek"]
    if provider in json_mode_providers:
        return True

    # LLM-004: OpenRouter models - use explicit allowlist instead of substring matching.
    # The allowlist stores bare ids, so drop the optional 'openrouter/' prefix first.
    if provider == "openrouter":
        return model.removeprefix("openrouter/") in OPENROUTER_JSON_CAPABLE_MODELS

    return False


def _appears_truncated(data: dict) -> bool:
    """LLM-001: Check if JSON data appears to be truncated.

    Detects suspicious patterns indicating incomplete responses: an empty
    workExperience/education/skills list, or a missing personalInfo section.

    Args:
        data: Parsed JSON value; non-dicts are never reported as truncated.

    Returns:
        True if a suspicious pattern was found (a warning is logged).
    """
    if not isinstance(data, dict):
        return False

    # Check for empty arrays that should typically have content
    suspicious_empty_arrays = ["workExperience", "education", "skills"]
    for key in suspicious_empty_arrays:
        if key in data and data[key] == []:
            # Log warning - these are rarely empty in real resumes
            logging.warning(
                "Possible truncation detected: '%s' is empty",
                key,
            )
            return True

    # Check for missing critical sections
    required_top_level = ["personalInfo"]
    for key in required_top_level:
        if key not in data:
            logging.warning(
                "Possible truncation detected: missing required section '%s'",
                key,
            )
            return True

    return False


def _get_retry_temperature(attempt: int, base_temp: float = 0.1) -> float:
    """LLM-002: Get temperature for retry attempt - increases with each retry.

    Higher temperature on retries gives the model more variation to produce
    different (hopefully valid) output.

    Args:
        attempt: Zero-based attempt index.
        base_temp: Temperature for the first attempt.

    Returns:
        base_temp, then 0.3, 0.5, and 0.7 for every later attempt.
    """
    temperatures = [base_temp, 0.3, 0.5, 0.7]
    return temperatures[min(attempt, len(temperatures) - 1)]


def _calculate_timeout(
    operation: str,
    max_tokens: int = 4096,
    provider: str = "openai",
) -> int:
    """LLM-005: Calculate adaptive timeout based on operation and parameters.

    Args:
        operation: "health_check", "completion" or "json"; unknown values use
            the completion timeout.
        max_tokens: Requested output budget; budgets above 4096 scale the
            timeout linearly, smaller ones never shrink it.
        provider: Provider name used for a latency multiplier (default 1.0).

    Returns:
        Timeout in whole seconds.
    """
    base_timeouts = {
        "health_check": LLM_TIMEOUT_HEALTH_CHECK,
        "completion": LLM_TIMEOUT_COMPLETION,
        "json": LLM_TIMEOUT_JSON,
    }
    base = base_timeouts.get(operation, LLM_TIMEOUT_COMPLETION)

    # Scale by token count (relative to 4096 baseline)
    token_factor = max(1.0, max_tokens / 4096)

    # Provider-specific latency adjustments
    provider_factors = {
        "openai": 1.0,
        "anthropic": 1.2,
        "openrouter": 1.5,  # More variable latency
        "ollama": 2.0,  # Local models can be slower
    }
    provider_factor = provider_factors.get(provider, 1.0)

    return int(base * token_factor * provider_factor)


def _extract_json(content: str, _depth: int = 0) -> str:
    """Extract JSON from LLM response, handling various formats.

    LLM-001: Improved to detect and reject likely truncated JSON.
    LLM-007: Improved error messages for debugging.
    JSON-010: Added recursion depth and size limits.

    Args:
        content: Raw model output, possibly wrapped in markdown fences or prose.
        _depth: Internal recursion counter.

    Returns:
        The substring holding the first balanced top-level JSON object.

    Raises:
        ValueError: If limits are exceeded or no balanced object is found.
    """
    # JSON-010: Safety limits
    if _depth > MAX_JSON_EXTRACTION_RECURSION:
        raise ValueError(f"JSON extraction exceeded max recursion depth: {_depth}")
    if len(content) > MAX_JSON_CONTENT_SIZE:
        raise ValueError(f"Content too large for JSON extraction: {len(content)} bytes")

    original = content

    # Remove markdown code blocks
    if "```json" in content:
        content = content.split("```json")[1].split("```")[0]
    elif "```" in content:
        parts = content.split("```")
        if len(parts) >= 2:
            content = parts[1]
            # Remove language identifier if present (e.g., "json\n{...")
            if content.startswith(("json", "JSON")):
                content = content[4:]

    content = content.strip()

    # If content starts with {, find the matching } while ignoring braces in strings.
    if content.startswith("{"):
        depth = 0
        end_idx = -1
        in_string = False
        escape_next = False

        for i, char in enumerate(content):
            if escape_next:
                escape_next = False
                continue
            if char == "\\":
                escape_next = True
                continue
            if char == '"':
                in_string = not in_string
                continue
            if in_string:
                continue
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    end_idx = i
                    break

        # LLM-001: Check for unbalanced braces - loop ended without depth reaching 0
        if end_idx == -1 and depth != 0:
            logging.warning(
                "JSON extraction found unbalanced braces (depth=%d), possible truncation",
                depth,
            )

        if end_idx != -1:
            return content[: end_idx + 1]

    # Try to find JSON object in the content (only if not already at start)
    start_idx = content.find("{")
    if start_idx > 0:
        # Only recurse if { is found after position 0 to avoid infinite recursion
        return _extract_json(content[start_idx:], _depth + 1)

    # LLM-007: Log unrecognized format for debugging
    logging.error(
        "Could not extract JSON from response format. Content preview: %s",
        content[:200] if content else "",
    )
    raise ValueError(f"No JSON found in response: {original[:200]}")


async def complete_json(
    prompt: str,
    system_prompt: str | None = None,
    config: LLMConfig | None = None,
    max_tokens: int = 4096,
    retries: int = 2,
) -> dict[str, Any]:
    """Make a completion request expecting JSON response.

    Uses JSON mode when available, with retry logic for reliability. When no
    ``config`` is given and ``CF_ORCH_URL`` is set, a cf-orch allocated vLLM
    backend is tried first, falling back to the default config on any failure.

    Args:
        prompt: User message.
        system_prompt: Optional system message; a JSON-only instruction is
            always appended.
        config: Explicit configuration; bypasses the cf-orch path when given.
        max_tokens: Maximum number of tokens to generate.
        retries: Extra attempts after the first one.

    Returns:
        The parsed JSON object.

    Raises:
        ValueError: If the JSON still fails to parse on the final attempt.
        Exception: Any other error from the final attempt is re-raised as-is.
    """
    if config is None:
        cf_orch_url = os.environ.get("CF_ORCH_URL", "").strip()
        if cf_orch_url:
            # Any exception in this block (allocation or the request itself on
            # the allocated backend) triggers the fallback to the default config.
            try:
                async with _allocate_orch_async(
                    cf_orch_url,
                    "vllm",
                    model_candidates=["Qwen2.5-3B-Instruct"],
                    ttl_s=300.0,
                    caller="peregrine-resume-matcher",
                ) as alloc:
                    orch_config = LLMConfig(
                        provider="openai",
                        model="__auto__",
                        api_key="any",
                        api_base=alloc.url.rstrip("/") + "/v1",
                    )
                    return await complete_json(prompt, system_prompt, orch_config, max_tokens, retries)
            except Exception as exc:
                logging.warning("cf-orch allocation failed, falling back to default config: %s", exc)
        config = get_llm_config()

    model_name = get_model_name(config)

    # Build messages
    json_system = (
        system_prompt or ""
    ) + "\n\nYou must respond with valid JSON only. No explanations, no markdown."

    messages = [
        {"role": "system", "content": json_system},
        {"role": "user", "content": prompt},
    ]

    # Check if we can use JSON mode
    use_json_mode = _supports_json_mode(config.provider, config.model)

    last_error = None
    for attempt in range(retries + 1):
        try:
            # Build request kwargs
            # Pass API key directly to avoid race conditions with global os.environ
            kwargs: dict[str, Any] = {
                "model": model_name,
                "messages": messages,
                "max_tokens": max_tokens,
                "api_key": config.api_key,
                "api_base": _normalize_api_base(config.provider, config.api_base),
                "timeout": _calculate_timeout("json", max_tokens, config.provider),
            }
            if _supports_temperature(config.provider, model_name):
                # LLM-002: Increase temperature on retry for variation
                kwargs["temperature"] = _get_retry_temperature(attempt)
            reasoning_effort = _get_reasoning_effort(config.provider, model_name)
            if reasoning_effort:
                kwargs["reasoning_effort"] = reasoning_effort

            # Add JSON mode if supported
            if use_json_mode:
                kwargs["response_format"] = {"type": "json_object"}

            response = await litellm.acompletion(**kwargs)
            content = _extract_choice_text(response.choices[0])

            if not content:
                raise ValueError("Empty response from LLM")

            logging.debug("LLM response (attempt %d): %s", attempt + 1, content[:300])

            # Extract and parse JSON
            json_str = _extract_json(content)
            result = json.loads(json_str)

            # LLM-001: Check if parsed result appears truncated
            if isinstance(result, dict) and _appears_truncated(result):
                logging.warning(
                    "Parsed JSON appears truncated, but proceeding with result"
                )

            return result

        except json.JSONDecodeError as e:
            last_error = e
            logging.warning("JSON parse failed (attempt %d): %s", attempt + 1, e)
            if attempt < retries:
                # Add hint to prompt for retry
                messages[-1]["content"] = (
                    prompt
                    + "\n\nIMPORTANT: Output ONLY a valid JSON object. Start with { and end with }."
                )
                continue
            raise ValueError(f"Failed to parse JSON after {retries + 1} attempts: {e}") from e
        except Exception as e:
            last_error = e
            logging.warning("LLM call failed (attempt %d): %s", attempt + 1, e)
            if attempt < retries:
                continue
            raise

    # Only reachable when the loop runs zero times (retries < 0).
    raise ValueError(f"Failed after {retries + 1} attempts: {last_error}")