feat(text): ollama + vllm proxy backends for cf-text
- backends/ollama.py: routes requests to a running Ollama instance via HTTP API - backends/vllm.py: routes requests to vllm's OpenAI-compatible API (/v1/chat/completions); cf-text holds no GPU memory in proxy mode
This commit is contained in:
parent
8d9b55ef8f
commit
146fe97227
2 changed files with 414 additions and 0 deletions
201
circuitforge_core/text/backends/ollama.py
Normal file
201
circuitforge_core/text/backends/ollama.py
Normal file
|
|
@@ -0,0 +1,201 @@
|
|||
# circuitforge_core/text/backends/ollama.py — Ollama proxy backend for cf-text
|
||||
#
|
||||
# Routes inference requests to a running Ollama instance via its HTTP API.
|
||||
# cf-text itself holds no GPU memory; Ollama manages the model and VRAM.
|
||||
#
|
||||
# Model path format: "ollama://<model-name>" e.g. "ollama://llama3.1:8b"
|
||||
# The "ollama://" prefix is stripped before forwarding to the API.
|
||||
#
|
||||
# Environment:
|
||||
# CF_TEXT_OLLAMA_URL Base URL of the Ollama server (default: http://localhost:11434)
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import json as _json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import AsyncIterator, Iterator
|
||||
|
||||
import httpx
|
||||
|
||||
from circuitforge_core.text.backends.base import GenerateResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
|
||||
class OllamaBackend:
    """
    cf-text backend that proxies inference to a local Ollama instance.

    This backend holds no GPU memory itself — Ollama owns the model and VRAM.
    vram_mb is therefore reported as 0 by default so cf-orch does not
    double-count VRAM against the separate ollama service budget.

    Supports /generate, /chat, and /v1/chat/completions (via generate/chat).
    Streaming is implemented for all variants.
    """

    # Per-request HTTP timeout in seconds. Deliberately generous: large
    # models can take minutes to finish a long completion.
    _REQUEST_TIMEOUT = 180.0

    def __init__(self, model_path: str, *, vram_mb: int = 0) -> None:
        """
        Args:
            model_path: Catalog path, e.g. "ollama://llama3.1:8b". The
                "ollama://" prefix is stripped before forwarding to the API.
            vram_mb: VRAM to report to the orchestrator; defaults to 0
                because Ollama, not cf-text, owns the GPU memory.
        """
        # Strip the "ollama://" prefix from catalog paths
        self._model = model_path.removeprefix("ollama://")
        self._url = os.environ.get("CF_TEXT_OLLAMA_URL", _DEFAULT_OLLAMA_URL).rstrip("/")
        self._vram_mb = vram_mb
        logger.info("OllamaBackend: model=%r url=%r", self._model, self._url)

    # ── Protocol properties ───────────────────────────────────────────────────

    @property
    def model_name(self) -> str:
        """Model name as known to Ollama (prefix already stripped)."""
        return self._model

    @property
    def vram_mb(self) -> int:
        # Ollama manages its own VRAM; cf-text holds nothing.
        return self._vram_mb

    # ── Internal helpers ──────────────────────────────────────────────────────

    def _generate_payload(
        self,
        prompt: str,
        *,
        max_tokens: int,
        temperature: float,
        stop: list[str] | None,
        stream: bool,
    ) -> dict:
        """Build the JSON body for POST /api/generate (shared by all variants)."""
        payload: dict = {
            "model": self._model,
            "prompt": prompt,
            "stream": stream,
            "options": {"temperature": temperature, "num_predict": max_tokens},
        }
        if stop:
            payload["options"]["stop"] = stop
        return payload

    # ── Synchronous interface ─────────────────────────────────────────────────

    def generate(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        """Blocking single-shot completion via POST /api/generate."""
        t0 = time.monotonic()
        payload = self._generate_payload(
            prompt, max_tokens=max_tokens, temperature=temperature, stop=stop, stream=False,
        )
        with httpx.Client(timeout=self._REQUEST_TIMEOUT) as client:
            resp = client.post(f"{self._url}/api/generate", json=payload)
            resp.raise_for_status()
            data = resp.json()
        # Previously the elapsed time was computed but silently dropped;
        # surface it at debug level instead.
        logger.debug("ollama generate: %d ms", round((time.monotonic() - t0) * 1000))
        return GenerateResult(
            text=data.get("response", ""),
            tokens_used=data.get("eval_count", 0),
            model=self._model,
        )

    def generate_stream(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> Iterator[str]:
        """Stream completion tokens via POST /api/generate (NDJSON lines).

        Yields each non-empty "response" fragment; stops at the "done" chunk.
        """
        payload = self._generate_payload(
            prompt, max_tokens=max_tokens, temperature=temperature, stop=stop, stream=True,
        )
        with httpx.Client(timeout=self._REQUEST_TIMEOUT) as client:
            with client.stream("POST", f"{self._url}/api/generate", json=payload) as resp:
                resp.raise_for_status()
                for line in resp.iter_lines():
                    if not line:
                        continue
                    chunk = _json.loads(line)
                    token = chunk.get("response", "")
                    if token:
                        yield token
                    if chunk.get("done"):
                        break

    def chat(
        self,
        messages: list[dict],
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> GenerateResult:
        """Blocking chat completion via POST /api/chat.

        Args:
            messages: OpenAI-style message dicts ({"role": ..., "content": ...}).
        """
        t0 = time.monotonic()
        payload: dict = {
            "model": self._model,
            "messages": messages,
            "stream": False,
            "options": {"temperature": temperature, "num_predict": max_tokens},
        }
        with httpx.Client(timeout=self._REQUEST_TIMEOUT) as client:
            resp = client.post(f"{self._url}/api/chat", json=payload)
            resp.raise_for_status()
            data = resp.json()
        logger.debug("ollama chat: %d ms", round((time.monotonic() - t0) * 1000))
        return GenerateResult(
            text=data.get("message", {}).get("content", ""),
            tokens_used=data.get("eval_count", 0),
            model=self._model,
        )

    # ── Async interface ───────────────────────────────────────────────────────

    async def generate_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        """Async counterpart of generate(); same endpoint and payload."""
        t0 = time.monotonic()
        payload = self._generate_payload(
            prompt, max_tokens=max_tokens, temperature=temperature, stop=stop, stream=False,
        )
        async with httpx.AsyncClient(timeout=self._REQUEST_TIMEOUT) as client:
            resp = await client.post(f"{self._url}/api/generate", json=payload)
            resp.raise_for_status()
            data = resp.json()
        logger.debug("ollama generate_async: %d ms", round((time.monotonic() - t0) * 1000))
        return GenerateResult(
            text=data.get("response", ""),
            tokens_used=data.get("eval_count", 0),
            model=self._model,
        )

    async def generate_stream_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        """Async counterpart of generate_stream(); yields token fragments."""
        payload = self._generate_payload(
            prompt, max_tokens=max_tokens, temperature=temperature, stop=stop, stream=True,
        )
        async with httpx.AsyncClient(timeout=self._REQUEST_TIMEOUT) as client:
            async with client.stream("POST", f"{self._url}/api/generate", json=payload) as resp:
                resp.raise_for_status()
                async for line in resp.aiter_lines():
                    if not line:
                        continue
                    chunk = _json.loads(line)
                    token = chunk.get("response", "")
                    if token:
                        yield token
                    if chunk.get("done"):
                        break
|
||||
213
circuitforge_core/text/backends/vllm.py
Normal file
213
circuitforge_core/text/backends/vllm.py
Normal file
|
|
@@ -0,0 +1,213 @@
|
|||
# circuitforge_core/text/backends/vllm.py — vllm proxy backend for cf-text
|
||||
#
|
||||
# Routes inference requests to a running vllm instance via its OpenAI-compatible
|
||||
# HTTP API (/v1/chat/completions, /v1/completions).
|
||||
# cf-text itself holds no GPU memory; vllm manages the model and VRAM.
|
||||
#
|
||||
# Model path format: "vllm://<model-id>" e.g. "vllm://Qwen/Qwen2.5-7B-Instruct"
|
||||
# The "vllm://" prefix is stripped; the remainder is the model_id sent to vllm.
|
||||
#
|
||||
# Environment:
|
||||
# CF_TEXT_VLLM_URL Base URL of the vllm server (default: http://localhost:8000)
|
||||
#
|
||||
# MIT licensed.
|
||||
from __future__ import annotations
|
||||
|
||||
import json as _json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import AsyncIterator, Iterator
|
||||
|
||||
import httpx
|
||||
|
||||
from circuitforge_core.text.backends.base import ChatMessage, GenerateResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_VLLM_URL = "http://localhost:8000"
|
||||
|
||||
|
||||
class VllmBackend:
    """
    cf-text backend that proxies inference to a local vllm instance.

    vllm exposes an OpenAI-compatible API (/v1/chat/completions).
    This backend holds no GPU memory — vllm owns the model and VRAM.
    vram_mb is reported as 0 so cf-orch does not double-count VRAM
    against the separate vllm service budget.
    """

    # Per-request HTTP timeout in seconds; generous because long completions
    # on large models can take minutes.
    _REQUEST_TIMEOUT = 180.0

    def __init__(self, model_path: str, *, vram_mb: int = 0) -> None:
        """
        Args:
            model_path: Catalog path, e.g. "vllm://Qwen/Qwen2.5-7B-Instruct".
                The "vllm://" prefix is stripped; the remainder is the
                model id sent to vllm.
            vram_mb: VRAM to report to the orchestrator; defaults to 0
                because vllm, not cf-text, owns the GPU memory.
        """
        # Strip the "vllm://" prefix from catalog paths
        self._model = model_path.removeprefix("vllm://")
        self._url = os.environ.get("CF_TEXT_VLLM_URL", _DEFAULT_VLLM_URL).rstrip("/")
        self._vram_mb = vram_mb
        logger.info("VllmBackend: model=%r url=%r", self._model, self._url)

    # ── Protocol properties ───────────────────────────────────────────────────

    @property
    def model_name(self) -> str:
        """Model id as sent to vllm (prefix already stripped)."""
        return self._model

    @property
    def vram_mb(self) -> int:
        # vllm manages its own VRAM; cf-text holds nothing.
        return self._vram_mb

    # ── Internal helpers ──────────────────────────────────────────────────────

    def _chat_payload(
        self,
        messages: list[dict],
        *,
        max_tokens: int,
        temperature: float,
        stop: list[str] | None,
        stream: bool,
    ) -> dict:
        """Build the JSON body for POST /v1/chat/completions."""
        payload: dict = {
            "model": self._model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": stream,
        }
        if stop:
            payload["stop"] = stop
        return payload

    def _prompt_as_messages(self, prompt: str) -> list[dict]:
        """Wrap a raw prompt as a single-user-turn chat message list."""
        return [{"role": "user", "content": prompt}]

    def _result_from_data(self, data: dict) -> GenerateResult:
        """Map an OpenAI-format chat-completion response onto GenerateResult."""
        return GenerateResult(
            text=data["choices"][0]["message"]["content"],
            tokens_used=data.get("usage", {}).get("completion_tokens", 0),
            model=self._model,
        )

    # ── Synchronous interface ─────────────────────────────────────────────────

    def generate(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        """Blocking single-shot completion via /v1/chat/completions."""
        payload = self._chat_payload(
            self._prompt_as_messages(prompt),
            max_tokens=max_tokens, temperature=temperature, stop=stop, stream=False,
        )
        with httpx.Client(timeout=self._REQUEST_TIMEOUT) as client:
            resp = client.post(f"{self._url}/v1/chat/completions", json=payload)
            resp.raise_for_status()
            data = resp.json()
        return self._result_from_data(data)

    def generate_stream(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> Iterator[str]:
        """Stream completion tokens via SSE from /v1/chat/completions."""
        payload = self._chat_payload(
            self._prompt_as_messages(prompt),
            max_tokens=max_tokens, temperature=temperature, stop=stop, stream=True,
        )
        with httpx.Client(timeout=self._REQUEST_TIMEOUT) as client:
            with client.stream("POST", f"{self._url}/v1/chat/completions", json=payload) as resp:
                resp.raise_for_status()
                for line in resp.iter_lines():
                    token = _parse_sse_token(line)
                    if token:
                        yield token

    def chat(
        self,
        messages: list[ChatMessage],
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> GenerateResult:
        """Blocking chat completion.

        Accepts either ChatMessage objects (anything exposing to_dict())
        or plain OpenAI-style dicts; both are normalized to dicts before
        being sent to vllm.
        """
        dicts = [m.to_dict() if hasattr(m, "to_dict") else m for m in messages]
        payload = self._chat_payload(
            dicts, max_tokens=max_tokens, temperature=temperature, stop=None, stream=False,
        )
        with httpx.Client(timeout=self._REQUEST_TIMEOUT) as client:
            resp = client.post(f"{self._url}/v1/chat/completions", json=payload)
            resp.raise_for_status()
            data = resp.json()
        return self._result_from_data(data)

    # ── Async interface ───────────────────────────────────────────────────────

    async def generate_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> GenerateResult:
        """Async counterpart of generate(); same endpoint and payload."""
        payload = self._chat_payload(
            self._prompt_as_messages(prompt),
            max_tokens=max_tokens, temperature=temperature, stop=stop, stream=False,
        )
        async with httpx.AsyncClient(timeout=self._REQUEST_TIMEOUT) as client:
            resp = await client.post(f"{self._url}/v1/chat/completions", json=payload)
            resp.raise_for_status()
            data = resp.json()
        return self._result_from_data(data)

    async def generate_stream_async(
        self,
        prompt: str,
        *,
        max_tokens: int = 512,
        temperature: float = 0.7,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        """Async counterpart of generate_stream(); yields token fragments."""
        payload = self._chat_payload(
            self._prompt_as_messages(prompt),
            max_tokens=max_tokens, temperature=temperature, stop=stop, stream=True,
        )
        async with httpx.AsyncClient(timeout=self._REQUEST_TIMEOUT) as client:
            async with client.stream("POST", f"{self._url}/v1/chat/completions", json=payload) as resp:
                resp.raise_for_status()
                async for line in resp.aiter_lines():
                    token = _parse_sse_token(line)
                    if token:
                        yield token
||||
|
||||
|
||||
# ── SSE parser (OpenAI/vllm format) ──────────────────────────────────────────
|
||||
|
||||
def _parse_sse_token(line: str) -> str:
|
||||
"""Extract content token from an OpenAI-format SSE line.
|
||||
|
||||
Lines look like: data: {"choices": [{"delta": {"content": "word"}}]}
|
||||
Terminal line: data: [DONE]
|
||||
Returns the token string, or "" for empty/done/non-data lines.
|
||||
"""
|
||||
if not line.startswith("data:"):
|
||||
return ""
|
||||
payload = line[5:].strip()
|
||||
if payload == "[DONE]":
|
||||
return ""
|
||||
try:
|
||||
chunk = _json.loads(payload)
|
||||
return chunk["choices"][0]["delta"].get("content", "") or ""
|
||||
except (KeyError, IndexError, _json.JSONDecodeError):
|
||||
return ""
|
||||
Loading…
Reference in a new issue