# vision/llm/main.py — LLM sidecar service (211 lines, 6.7 KiB, Python).
# FastAPI app that supervises a local llama-server process and proxies
# OpenAI-compatible requests to it.
from __future__ import annotations
import asyncio
import logging
import os
import shlex
import subprocess
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, Optional
import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
# Module-wide logger for the LLM sidecar service.
logger = logging.getLogger("llm")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
# Filesystem path of the GGUF model file served by llama-server.
MODEL_PATH = os.getenv("MODEL_PATH", "/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf")
# Alias announced to clients via llama-server's --alias flag.
LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "qwen3-1.7b-instruct-q4_k_m")
# Prompt context window in tokens, passed as --ctx-size.
LLM_CONTEXT_SIZE = int(os.getenv("LLM_CONTEXT_SIZE", "4096"))
# CPU threads for inference, passed as --threads.
LLM_THREADS = int(os.getenv("LLM_THREADS", "4"))
# Layers offloaded to GPU (--n-gpu-layers); 0 keeps inference on CPU.
LLM_GPU_LAYERS = int(os.getenv("LLM_GPU_LAYERS", "0"))
# Loopback port the supervised llama-server listens on.
LLAMA_SERVER_PORT = int(os.getenv("LLAMA_SERVER_PORT", "8081"))
# Seconds to wait for llama-server to pass readiness probes at startup.
LLM_STARTUP_TIMEOUT = float(os.getenv("LLM_STARTUP_TIMEOUT", "120"))
# Extra CLI arguments appended verbatim (shlex-split) to the llama-server command.
LLM_EXTRA_ARGS = os.getenv("LLM_EXTRA_ARGS", "")
# Handle to the supervised llama-server child process; None until lifespan starts it.
_llama_process: subprocess.Popen[bytes] | None = None
# Shared async HTTP client for all upstream calls; created/closed by lifespan.
_http_client: httpx.AsyncClient | None = None
def _upstream_base_url() -> str:
    """Base URL of the local llama-server upstream."""
    return "http://127.0.0.1:{}".format(LLAMA_SERVER_PORT)
def _ensure_http_client() -> httpx.AsyncClient:
    """Return the shared AsyncClient, raising if lifespan has not created it yet."""
    client = _http_client
    if client is None:
        raise RuntimeError("HTTP client not initialised")
    return client
def _validate_model_path() -> None:
    """Fail fast when the configured model file is absent or unreadable.

    Raises:
        RuntimeError: if MODEL_PATH does not exist or lacks read permission.
    """
    path = Path(MODEL_PATH)
    if not path.is_file():
        raise RuntimeError(f"model file not found at {MODEL_PATH}")
    if not os.access(path, os.R_OK):
        raise RuntimeError(f"model file is not readable at {MODEL_PATH}")
def _build_llama_command() -> list[str]:
    """Assemble the llama-server argv from the module-level configuration."""
    flag_values = [
        ("--host", "127.0.0.1"),
        ("--port", str(LLAMA_SERVER_PORT)),
        ("--model", MODEL_PATH),
        ("--alias", LLM_MODEL_NAME),
        ("--ctx-size", str(LLM_CONTEXT_SIZE)),
        ("--threads", str(LLM_THREADS)),
        ("--n-gpu-layers", str(LLM_GPU_LAYERS)),
    ]
    command = ["/usr/local/bin/llama-server"]
    for flag, value in flag_values:
        command.append(flag)
        command.append(value)
    # Operator-supplied extras are appended last so they can override defaults.
    if LLM_EXTRA_ARGS.strip():
        command.extend(shlex.split(LLM_EXTRA_ARGS))
    return command
def _llama_running() -> bool:
    """True while the supervised llama-server child process is alive."""
    proc = _llama_process
    if proc is None:
        return False
    return proc.poll() is None
async def _wait_for_llama_ready() -> None:
    """Poll the upstream /v1/models endpoint until llama-server answers 200.

    Raises:
        RuntimeError: if the child process exits first, or the readiness
            deadline (LLM_STARTUP_TIMEOUT seconds) passes without a 200;
            the last probe error is chained as the cause.
    """
    deadline = time.monotonic() + LLM_STARTUP_TIMEOUT
    probe_url = f"{_upstream_base_url()}/v1/models"  # loop-invariant, hoisted
    last_error: Optional[Exception] = None
    while time.monotonic() < deadline:
        # A dead child will never become ready; fail fast with its exit code.
        if _llama_process is not None:
            exit_code = _llama_process.poll()
            if exit_code is not None:
                raise RuntimeError(f"llama-server exited with code {exit_code}")
        try:
            response = await _ensure_http_client().get(probe_url, timeout=5)
            if response.status_code == 200:
                logger.info("llm service: llama-server ready")
                return
        except Exception as exc:  # expected while the server is still booting
            last_error = exc
        await asyncio.sleep(1)
    # Chain the last probe failure so the startup traceback shows the cause.
    raise RuntimeError(
        f"llama-server did not become ready within {LLM_STARTUP_TIMEOUT}s: {last_error}"
    ) from last_error
async def _stop_llama_process() -> None:
    """Terminate the llama-server child (escalating to SIGKILL) and clear the handle."""
    global _llama_process
    proc = _llama_process
    if proc is None:
        return
    if proc.poll() is None:
        # Graceful first: SIGTERM, then SIGKILL if it ignores us for 10s.
        proc.terminate()
        try:
            await asyncio.to_thread(proc.wait, timeout=10)
        except subprocess.TimeoutExpired:
            proc.kill()
            await asyncio.to_thread(proc.wait, timeout=5)
    _llama_process = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run llama-server and the shared HTTP client for the app's lifetime.

    Startup: validate the model file, create the client, spawn the child
    process, then block until it answers readiness probes. Shutdown (and
    any startup failure) stops the child and closes the client.
    """
    global _http_client, _llama_process
    _validate_model_path()
    _http_client = httpx.AsyncClient(timeout=httpx.Timeout(120, connect=5))
    command = _build_llama_command()
    logger.info(
        "llm service: starting llama-server model=%s ctx=%s threads=%s gpu_layers=%s upstream_port=%s",
        LLM_MODEL_NAME, LLM_CONTEXT_SIZE, LLM_THREADS, LLM_GPU_LAYERS, LLAMA_SERVER_PORT,
    )
    try:
        # Spawn inside the try: if Popen itself fails (e.g. missing binary),
        # the finally block still closes the HTTP client instead of leaking it.
        _llama_process = subprocess.Popen(command)
        await _wait_for_llama_ready()
        yield
    finally:
        await _stop_llama_process()
        if _http_client is not None:
            await _http_client.aclose()
            _http_client = None
# ASGI application; lifespan supervises the llama-server child process.
app = FastAPI(title="Skinbase LLM Service", version="1.0.0", lifespan=lifespan)
def _health_payload(status: str) -> Dict[str, Any]:
    """Build the /health response body carrying the given status string."""
    payload: Dict[str, Any] = {"status": status}
    payload["model"] = Path(MODEL_PATH).name
    payload["model_alias"] = LLM_MODEL_NAME
    payload["context_size"] = LLM_CONTEXT_SIZE
    payload["threads"] = LLM_THREADS
    payload["gpu_layers"] = LLM_GPU_LAYERS
    return payload
async def _proxy_request(method: str, path: str, *, body: bytes | None = None) -> Dict[str, Any]:
    """Forward one request to llama-server and return its decoded JSON body.

    Args:
        method: HTTP verb to use upstream.
        path: upstream path (including the leading slash).
        body: raw JSON request body to forward, or None for body-less requests.

    Raises:
        HTTPException: 503 when the child is down or unreachable, 504 on
            upstream timeout, 502 on non-JSON upstream output, or the
            upstream's own status when it answers with >= 400.
    """
    if not _llama_running():
        raise HTTPException(status_code=503, detail="llama-server is not running")
    headers = {"content-type": "application/json"} if body is not None else None
    try:
        response = await _ensure_http_client().request(
            method,
            f"{_upstream_base_url()}{path}",
            content=body,
            headers=headers,
            timeout=httpx.Timeout(120, connect=5),
        )
    # TimeoutException subclasses RequestError, so it must be caught first.
    except httpx.TimeoutException as exc:
        # "from exc" keeps the original httpx error in the traceback chain.
        raise HTTPException(status_code=504, detail=f"llama-server timed out: {exc}") from exc
    except httpx.RequestError as exc:
        raise HTTPException(status_code=503, detail=f"llama-server unavailable: {exc}") from exc
    if response.status_code >= 400:
        detail: Any
        try:
            detail = response.json()
        except Exception:
            # Upstream error body was not JSON; pass through a bounded excerpt.
            detail = response.text[:1000]
        raise HTTPException(status_code=response.status_code, detail=detail)
    try:
        return response.json()
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"llama-server returned invalid JSON: {exc}") from exc
@app.exception_handler(HTTPException)
async def handle_http_exception(_: Request, exc: HTTPException):
    """Render every HTTPException in the service's uniform error envelope."""
    envelope = {"error": {"code": "llm_service_error", "message": str(exc.detail)}}
    return JSONResponse(status_code=exc.status_code, content=envelope)
@app.get("/health")
async def health():
    """Health probe: verifies both the child process and the upstream HTTP API."""
    if not _llama_running():
        return JSONResponse(status_code=503, content=_health_payload("unavailable"))
    # Process is alive; confirm the HTTP API actually answers.
    try:
        upstream = await _ensure_http_client().get(f"{_upstream_base_url()}/v1/models", timeout=5)
        if upstream.status_code != 200:
            return JSONResponse(status_code=503, content=_health_payload("degraded"))
    except Exception:
        return JSONResponse(status_code=503, content=_health_payload("degraded"))
    return _health_payload("ok")
@app.get("/v1/models")
async def list_models():
    """Expose the upstream model listing verbatim."""
    models = await _proxy_request("GET", "/v1/models")
    return models
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Proxy a chat-completion request body to llama-server unchanged."""
    raw_body = await request.body()
    return await _proxy_request("POST", "/v1/chat/completions", body=raw_body)