diff --git a/.env.example b/.env.example index 17b026a..e69a22c 100644 --- a/.env.example +++ b/.env.example @@ -7,6 +7,13 @@ CLIP_URL=http://clip:8000 BLIP_URL=http://blip:8000 YOLO_URL=http://yolo:8000 QDRANT_SVC_URL=http://qdrant-svc:8000 +LLM_URL=http://llm:8080 +LLM_ENABLED=false +LLM_TIMEOUT=120 +LLM_DEFAULT_MODEL=qwen3-1.7b-instruct-q4_k_m +LLM_MAX_TOKENS_DEFAULT=256 +LLM_MAX_TOKENS_HARD_LIMIT=1024 +LLM_MAX_REQUEST_BYTES=65536 # HuggingFace token for private/gated models (optional). Leave empty if unused. # Never commit a real token to this file. @@ -21,3 +28,10 @@ VECTOR_DIM=512 # Gateway runtime VISION_TIMEOUT=300 MAX_IMAGE_BYTES=52428800 + +# Local llama.cpp LLM service (only needed when you run the llm profile locally) +MODEL_PATH=/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf +LLM_CONTEXT_SIZE=4096 +LLM_THREADS=4 +LLM_GPU_LAYERS=0 +LLM_EXTRA_ARGS= diff --git a/.gitignore b/.gitignore index 6988ef4..0a6bd0e 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,7 @@ qdrant_data/ *.pth *.bin *.ckpt +*.gguf # Numpy arrays *.npy diff --git a/README.md b/README.md index 065a735..2387143 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ -# Skinbase Vision Stack (CLIP + BLIP + YOLO + Qdrant + Card Renderer + Maturity) – Dockerized FastAPI +# Skinbase Vision Stack (CLIP + BLIP + YOLO + Qdrant + Card Renderer + Maturity + LLM) – Dockerized FastAPI -This repository provides **six standalone vision services** (CLIP / BLIP / YOLO / Qdrant / Card Renderer / Maturity) -and a **Gateway API** that can call them individually or together. +This repository provides internal AI services for image analysis, vector search, card rendering, moderation, +and text generation behind a single **Gateway API**. ## Services & Ports @@ -13,6 +13,7 @@ and a **Gateway API** that can call them individually or together. - `qdrant-svc`: internal Qdrant API wrapper - `card-renderer`: internal card rendering service - `maturity`: internal NSFW/maturity classifier service +- `llm`: internal text-generation service using a thin FastAPI shim over `llama-server` (profile-based, internal only) ## Run @@ -20,6 +21,16 @@ and a **Gateway API** that can call them individually or together. docker compose up -d --build ``` +That starts the default vision stack only. The LLM service is disabled by default so operators are not forced to run Qwen3 on the same host. + +To also start the local llama.cpp service: + +```bash +docker compose --profile llm up -d --build +``` + +Before enabling the `llm` profile locally, place the GGUF model file described in [models/qwen3/README.md](models/qwen3/README.md) and set `LLM_ENABLED=true` in `.env`. + If you use BLIP, create a `.env` file first. Required variables: @@ -40,6 +51,26 @@ MATURITY_THRESHOLD_REVIEW=0.60 MATURITY_ENABLED=true ``` +Optional LLM configuration: + +```bash +LLM_ENABLED=false +LLM_URL=http://llm:8080 +LLM_DEFAULT_MODEL=qwen3-1.7b-instruct-q4_k_m +LLM_TIMEOUT=120 +LLM_MAX_TOKENS_DEFAULT=256 +LLM_MAX_TOKENS_HARD_LIMIT=1024 +LLM_MAX_REQUEST_BYTES=65536 + +# Local llm profile only +MODEL_PATH=/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf +LLM_CONTEXT_SIZE=4096 +LLM_THREADS=4 +LLM_GPU_LAYERS=0 +``` + +Recommended production topology for the LLM: keep the gateway on the current vision host and point `LLM_URL` at a separate private machine or VPN-reachable container host. Running the full vision stack and Qwen3 together on a small 4c/8GB VPS will usually degrade both. + Service startup now waits on container healthchecks, so first boot may take longer while models finish loading. ## Health @@ -48,6 +79,71 @@ Service startup now waits on container healthchecks, so first boot may take long curl -H "X-API-Key: " https://vision.klevze.net/health ``` +LLM-specific gateway health: + +```bash +curl -H "X-API-Key: " https://vision.klevze.net/ai/health +``` + +## LLM Smoke Test + +Use this checklist on a Docker-capable host after provisioning the GGUF file and setting `LLM_ENABLED=true`. + +1. Start the gateway and local LLM profile. + +```bash +docker compose --profile llm up -d --build gateway llm +``` + +2. Confirm the LLM container is running and healthy. + +```bash +docker compose ps llm +docker compose logs --tail=100 llm +``` + +3. Check the internal LLM health contract. + +```bash +curl http://127.0.0.1:8080/health +``` + +Expected fields: `status`, `model`, `context_size`, `threads`. + +4. Check gateway health and LLM reachability. + +```bash +curl -H "X-API-Key: " http://127.0.0.1:8003/health +curl -H "X-API-Key: " http://127.0.0.1:8003/ai/health +``` + +5. Verify model discovery through the gateway. + +```bash +curl -H "X-API-Key: " http://127.0.0.1:8003/v1/models +``` + +6. Run a short non-streaming chat completion. + +```bash +curl -H "X-API-Key: " -X POST http://127.0.0.1:8003/ai/chat \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "system", "content": "You are a concise assistant for Skinbase Nova."}, + {"role": "user", "content": "Write one sentence about an artist who creates cinematic sci-fi wallpaper packs."} + ], + "max_tokens": 80 + }' +``` + +7. If anything fails, inspect the two relevant services first. + +```bash +docker compose logs --tail=200 llm +docker compose logs --tail=200 gateway +``` + ## Universal analyze (ALL) ### With URL @@ -271,11 +367,51 @@ curl -H "X-API-Key: " -X POST https://vision.klevze.net/cards/rend -d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","title":"Artwork Title"}' ``` +## LLM / Chat Completions + +The gateway exposes stable text-generation endpoints backed by the internal `llm` service. They reuse the existing `X-API-Key` protection and keep the LLM container internal-only. + +### OpenAI-style chat endpoint +```bash +curl -H "X-API-Key: " -X POST https://vision.klevze.net/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "system", "content": "You are a concise assistant for Skinbase Nova."}, + {"role": "user", "content": "Write a short creator biography for an artist who just hit 10,000 followers."} + ], + "temperature": 0.7, + "max_tokens": 220, + "stream": false + }' +``` + +### Project-friendly chat endpoint +```bash +curl -H "X-API-Key: " -X POST https://vision.klevze.net/ai/chat \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "system", "content": "You are a concise assistant for Skinbase Nova."}, + {"role": "user", "content": "Suggest metadata tags for a cyberpunk wallpaper pack."} + ], + "max_tokens": 180 + }' +``` + +### List models +```bash +curl -H "X-API-Key: " https://vision.klevze.net/v1/models +curl -H "X-API-Key: " https://vision.klevze.net/ai/models +``` + ## Notes - Models are loaded at service startup; initial container start can take 1–2 minutes as model weights are downloaded. - Qdrant data is persisted in the project folder at `./data/qdrant`, so it survives container restarts and recreates. +- The local `llm` profile does **not** auto-download Qwen3 weights. Mount the GGUF file explicitly and let startup fail fast if it is missing. - Remote image URLs are restricted to public `http`/`https` hosts. Localhost, private IP ranges, and non-image content types are rejected. - The maturity service uses `Falconsai/nsfw_image_detection` (ViT-based). Thresholds are configurable via `.env`. The model handles photos and stylized digital art but should be calibrated against real Skinbase content before production use. +- For small VPS deployments, prefer `LLM_ENABLED=true` with `LLM_URL` pointing to a separate LLM host instead of running the `llm` profile on the same machine. - For production: add auth, rate limits, and restrict gateway exposure (private network). - GPU: you can add NVIDIA runtime later (compose profiles) if needed. diff --git a/USAGE.md b/USAGE.md index ee03b89..d513409 100644 --- a/USAGE.md +++ b/USAGE.md @@ -1,10 +1,10 @@ # Skinbase Vision Stack — Usage Guide -This document explains how to run and use the Skinbase Vision Stack (Gateway + CLIP, BLIP, YOLO, Qdrant, Card Renderer, Maturity services). +This document explains how to run and use the Skinbase Vision Stack (Gateway + CLIP, BLIP, YOLO, Qdrant, Card Renderer, Maturity, and optional LLM services). ## Overview -- Services: `gateway`, `clip`, `blip`, `yolo`, `qdrant`, `qdrant-svc`, `card-renderer`, `maturity` (FastAPI each, except `qdrant` which is the official Qdrant DB). +- Services: `gateway`, `clip`, `blip`, `yolo`, `qdrant`, `qdrant-svc`, `card-renderer`, `maturity`, `llm` (FastAPI each except `qdrant`; `llm` is a thin FastAPI shim that manages an internal `llama-server` process). - Gateway is the public API endpoint; the other services are internal. ## Model overview @@ -21,6 +21,8 @@ This document explains how to run and use the Skinbase Vision Stack (Gateway + C - **Maturity**: Dedicated NSFW/maturity classifier. Accepts an image and returns a normalized safety signal including `maturity_label` (`safe`/`mature`), `confidence`, raw `score`, optional sublabels (e.g. `nsfw`), and an `action_hint` (`safe`, `review`, `flag_high`) designed for Nova moderation workflows. Powered by `Falconsai/nsfw_image_detection` (ViT-based, HuggingFace). Thresholds are configurable via environment variables. +- **LLM**: Internal text-generation service backed by `llama.cpp` and a GGUF Qwen3 model. Exposed through the gateway for non-streaming chat completions and model discovery. Intended for Nova workflows such as creator bios, metadata suggestions, moderation helper text, and other short internal generation tasks. + ## Prerequisites - Docker Desktop (with `docker compose`) or a Docker environment. @@ -55,12 +57,48 @@ MATURITY_ENABLED=true - `MATURITY_THRESHOLD_REVIEW`: score above this but below mature threshold → `mature` + `review` (default `0.60`). - `MATURITY_ENABLED`: set to `false` to disable maturity endpoints at the gateway without removing the service. +Optional LLM configuration: + +```bash +LLM_URL=http://llm:8080 +LLM_ENABLED=false +LLM_TIMEOUT=120 +LLM_DEFAULT_MODEL=qwen3-1.7b-instruct-q4_k_m +LLM_MAX_TOKENS_DEFAULT=256 +LLM_MAX_TOKENS_HARD_LIMIT=1024 +LLM_MAX_REQUEST_BYTES=65536 + +# Local llm profile only +MODEL_PATH=/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf +LLM_CONTEXT_SIZE=4096 +LLM_THREADS=4 +LLM_GPU_LAYERS=0 +LLM_EXTRA_ARGS= +``` + Run from repository root: ```bash docker compose up -d --build ``` +That starts the default vision stack only. + +To also start the local LLM service: + +```bash +docker compose --profile llm up -d --build +``` + +Before enabling the `llm` profile, provision the GGUF model described in [models/qwen3/README.md](models/qwen3/README.md) and set `LLM_ENABLED=true` in `.env`. + +For small production hosts, the preferred setup is usually to keep the gateway local and point `LLM_URL` at a separate private LLM host: + +```bash +LLM_ENABLED=true +LLM_URL=http://private-llm-host:8080 +``` + Stop: ```bash @@ -82,6 +120,74 @@ Check the gateway health endpoint: curl https://vision.klevze.net/health ``` +Check LLM-specific gateway health: + +```bash +curl -H "X-API-Key: " https://vision.klevze.net/ai/health +``` + +## LLM smoke test checklist + +Use this sequence on a machine with Docker available after you have mounted the GGUF model and enabled the gateway with `LLM_ENABLED=true`. + +1. Start the gateway with the `llm` profile. + +```bash +docker compose --profile llm up -d --build gateway llm +``` + +2. Confirm the LLM service came up cleanly. + +```bash +docker compose ps llm +docker compose logs --tail=100 llm +``` + +3. Check the repo-owned internal health endpoint. + +```bash +curl http://127.0.0.1:8080/health +``` + +Expected fields: `status`, `model`, `context_size`, `threads`. + +4. Confirm the gateway sees the LLM backend. + +```bash +curl -H "X-API-Key: " http://127.0.0.1:8003/health +curl -H "X-API-Key: " http://127.0.0.1:8003/ai/health +``` + +5. Verify model discovery. + +```bash +curl -H "X-API-Key: " http://127.0.0.1:8003/v1/models +curl -H "X-API-Key: " http://127.0.0.1:8003/ai/models +``` + +6. Run a small chat request through the gateway. + +```bash +curl -X POST http://127.0.0.1:8003/v1/chat/completions \ + -H "X-API-Key: " \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "system", "content": "You are a concise assistant for Skinbase Nova."}, + {"role": "user", "content": "Write one short admin help sentence about reviewing wallpaper metadata."} + ], + "max_tokens": 60, + "stream": false + }' +``` + +7. If startup or health fails, inspect the relevant logs. + +```bash +docker compose logs --tail=200 llm +docker compose logs --tail=200 gateway +``` + ## Universal analyze (ALL) Analyze an image by URL (gateway aggregates CLIP, BLIP, YOLO): @@ -241,7 +347,93 @@ Response fields: - `review`: score ≥ `MATURITY_THRESHOLD_REVIEW` (default 0.60) but below mature threshold — possible mature, queue for human review. - `safe`: score below both thresholds — content appears safe. -If the maturity service is unavailable the gateway returns a `502` or `503` error. **Nova must not treat a gateway failure as a `safe` result** — retry or queue for later processing. store image embeddings and find visually similar images. Embeddings are generated automatically by the CLIP service. +If the maturity service is unavailable the gateway returns a `502` or `503` error. **Nova must not treat a gateway failure as a `safe` result** — retry or queue for later processing. + +## LLM / Chat endpoints + +The gateway validates requests, clamps `max_tokens` to configured limits, rejects oversized payloads, and normalizes downstream failures into JSON under an `error` key. + +### OpenAI-style chat completions + +```bash +curl -X POST https://vision.klevze.net/v1/chat/completions \ + -H "X-API-Key: " \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "system", "content": "You are a concise assistant for Skinbase Nova."}, + {"role": "user", "content": "Write a short biography for a creator known for sci-fi environments."} + ], + "temperature": 0.7, + "max_tokens": 220, + "stream": false + }' +``` + +Supported request fields: +- `messages` (required) +- `temperature` +- `max_tokens` +- `stream` (`false` only in v1) +- `top_p` +- `stop` +- `presence_penalty` +- `frequency_penalty` + +Validation rules: +- At least one message is required. +- Roles must be `system`, `user`, or `assistant`. +- Empty message content is rejected. +- Oversized request bodies return `413`. +- `max_tokens` is clamped to `LLM_MAX_TOKENS_HARD_LIMIT`. + +### Project-friendly chat response + +```bash +curl -X POST https://vision.klevze.net/ai/chat \ + -H "X-API-Key: " \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "system", "content": "You are a helpful metadata assistant."}, + {"role": "user", "content": "Suggest five tags for a fantasy castle wallpaper."} + ] + }' +``` + +Example response: + +```json +{ + "model": "qwen3-1.7b-instruct-q4_k_m", + "content": "fantasy castle, moonlit fortress, medieval towers, epic landscape, digital painting", + "finish_reason": "stop", + "usage": { + "prompt_tokens": 48, + "completion_tokens": 19, + "total_tokens": 67 + } +} +``` + +### Model discovery + +```bash +curl -H "X-API-Key: " https://vision.klevze.net/v1/models +curl -H "X-API-Key: " https://vision.klevze.net/ai/models +``` + +### Failure modes + +- `401`: missing or invalid API key +- `413`: request body exceeds `LLM_MAX_REQUEST_BYTES` +- `422`: validation failure or unsupported streaming request +- `503`: LLM disabled or upstream unavailable +- `504`: upstream timeout + +## Vector DB (Qdrant) + +Use the Qdrant gateway endpoints to store image embeddings and find visually similar images. Embeddings are generated automatically by the CLIP service. Qdrant point IDs must be either an unsigned integer or a UUID string. If you send another string value, the wrapper may replace it with a generated UUID and store the original value in metadata as `_original_id`. diff --git a/docker-compose.yml b/docker-compose.yml index cd764cb..ba93bec 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -14,6 +14,13 @@ services: - QDRANT_SVC_URL=http://qdrant-svc:8000 - CARD_RENDERER_URL=http://card-renderer:8000 - MATURITY_URL=http://maturity:8000 + - LLM_URL=${LLM_URL:-http://llm:8080} + - LLM_ENABLED=${LLM_ENABLED:-false} + - LLM_TIMEOUT=${LLM_TIMEOUT:-120} + - LLM_DEFAULT_MODEL=${LLM_DEFAULT_MODEL:-qwen3-1.7b-instruct-q4_k_m} + - LLM_MAX_TOKENS_DEFAULT=${LLM_MAX_TOKENS_DEFAULT:-256} + - LLM_MAX_TOKENS_HARD_LIMIT=${LLM_MAX_TOKENS_HARD_LIMIT:-1024} + - LLM_MAX_REQUEST_BYTES=${LLM_MAX_REQUEST_BYTES:-65536} - MATURITY_ENABLED=true - API_KEY=${API_KEY} - VISION_TIMEOUT=300 @@ -151,3 +158,26 @@ services: retries: 5 start_period: 90s + llm: + build: + context: . + dockerfile: llm/Dockerfile + environment: + - MODEL_PATH=${MODEL_PATH:-/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf} + - LLM_MODEL_NAME=${LLM_DEFAULT_MODEL:-qwen3-1.7b-instruct-q4_k_m} + - LLM_CONTEXT_SIZE=${LLM_CONTEXT_SIZE:-4096} + - LLM_THREADS=${LLM_THREADS:-4} + - LLM_GPU_LAYERS=${LLM_GPU_LAYERS:-0} + - LLM_PORT=8080 + - LLM_EXTRA_ARGS=${LLM_EXTRA_ARGS:-} + volumes: + - ./models/qwen3:/models:ro + healthcheck: + test: ["CMD", "curl", "-fsS", "http://127.0.0.1:8080/health"] + interval: 30s + timeout: 10s + retries: 5 + start_period: 120s + profiles: + - llm + diff --git a/gateway/main.py b/gateway/main.py index 59c7e01..3e2212f 100644 --- a/gateway/main.py +++ b/gateway/main.py @@ -1,17 +1,18 @@ from __future__ import annotations import asyncio +import json import logging import os import time from contextlib import asynccontextmanager -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Literal, Optional import httpx from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Request from fastapi.responses import JSONResponse, Response from starlette.middleware.base import BaseHTTPMiddleware -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ValidationError, field_validator logger = logging.getLogger("gateway") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s") @@ -23,6 +24,16 @@ QDRANT_SVC_URL = os.getenv("QDRANT_SVC_URL", "http://qdrant-svc:8000") CARD_RENDERER_URL = os.getenv("CARD_RENDERER_URL", "http://card-renderer:8000") MATURITY_URL = os.getenv("MATURITY_URL", "http://maturity:8000") MATURITY_ENABLED = os.getenv("MATURITY_ENABLED", "true").lower() not in ("0", "false", "no") +LLM_URL = os.getenv("LLM_URL", "http://llm:8080") +LLM_ENABLED = os.getenv("LLM_ENABLED", "false").lower() not in ("0", "false", "no") +LLM_DEFAULT_MODEL = os.getenv("LLM_DEFAULT_MODEL", "qwen3-1.7b-instruct-q4_k_m") +LLM_TIMEOUT = float(os.getenv("LLM_TIMEOUT", "120")) +LLM_MAX_TOKENS_HARD_LIMIT = max(1, int(os.getenv("LLM_MAX_TOKENS_HARD_LIMIT", "1024"))) +LLM_MAX_TOKENS_DEFAULT = min( + LLM_MAX_TOKENS_HARD_LIMIT, + max(1, int(os.getenv("LLM_MAX_TOKENS_DEFAULT", "256"))), +) +LLM_MAX_REQUEST_BYTES = max(1024, int(os.getenv("LLM_MAX_REQUEST_BYTES", "65536"))) VISION_TIMEOUT = float(os.getenv("VISION_TIMEOUT", "20")) # API key (set via env var `API_KEY`). If not set, gateway will reject requests. @@ -36,6 +47,21 @@ API_KEY = os.getenv("API_KEY") _http_client: httpx.AsyncClient | None = None +class LLMGatewayError(Exception): + def __init__( + self, + status_code: int, + code: str, + message: str, + details: Optional[Any] = None, + ): + self.status_code = status_code + self.code = code + self.message = message + self.details = details + super().__init__(message) + + def get_http_client() -> httpx.AsyncClient: """Return the shared httpx client. Raises if called before lifespan starts.""" if _http_client is None: @@ -74,6 +100,17 @@ async def lifespan(app: FastAPI): except Exception as exc: logger.warning("gateway startup: qdrant-svc warm ping failed (non-fatal): %s", exc) + if LLM_ENABLED: + try: + t_warm = time.perf_counter() + r = await _http_client.get(f"{LLM_URL}/health", timeout=min(LLM_TIMEOUT, 10)) + logger.info( + "gateway startup: llm warm ping done status=%s elapsed_ms=%.1f", + r.status_code, (time.perf_counter() - t_warm) * 1000, + ) + except Exception as exc: + logger.warning("gateway startup: llm warm ping failed (non-fatal): %s", exc) + logger.info("gateway startup complete elapsed_ms=%.1f", (time.perf_counter() - t0) * 1000) yield # application runs @@ -90,13 +127,31 @@ class APIKeyMiddleware(BaseHTTPMiddleware): return await call_next(request) key = request.headers.get("x-api-key") or request.headers.get("X-API-Key") if not API_KEY or key != API_KEY: + if _is_llm_path(request.url.path): + return JSONResponse( + status_code=401, + content={"error": {"code": "unauthorized", "message": "Unauthorized"}}, + ) return JSONResponse(status_code=401, content={"detail": "Unauthorized"}) return await call_next(request) + +def _is_llm_path(path: str) -> bool: + return path.startswith("/v1/") or path.startswith("/ai/") + + app = FastAPI(title="Skinbase Vision Gateway", version="1.0.0", lifespan=lifespan) app.add_middleware(APIKeyMiddleware) +@app.exception_handler(LLMGatewayError) +async def handle_llm_gateway_error(_: Request, exc: LLMGatewayError): + error: Dict[str, Any] = {"code": exc.code, "message": exc.message} + if exc.details is not None: + error["details"] = exc.details + return JSONResponse(status_code=exc.status_code, content={"error": error}) + + class ClipRequest(BaseModel): url: Optional[str] = None limit: int = Field(default=5, ge=1, le=50) @@ -118,6 +173,219 @@ class MaturityRequest(BaseModel): url: Optional[str] = None +class ChatMessage(BaseModel): + role: Literal["system", "user", "assistant"] + content: str + + @field_validator("content") + @classmethod + def validate_content(cls, value: str) -> str: + if not value or not value.strip(): + raise ValueError("message content must not be empty") + return value + + +class ChatCompletionRequest(BaseModel): + model: Optional[str] = None + messages: List[ChatMessage] = Field(min_length=1, max_length=100) + temperature: Optional[float] = None + max_tokens: Optional[int] = Field(default=None, ge=1) + stream: bool = False + top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) + stop: Optional[str | List[str]] = None + presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0) + frequency_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0) + + @field_validator("model") + @classmethod + def validate_model(cls, value: Optional[str]) -> Optional[str]: + if value is None: + return value + model = value.strip() + if not model: + raise ValueError("model must not be empty") + return model + + @field_validator("temperature") + @classmethod + def validate_temperature(cls, value: Optional[float]) -> Optional[float]: + if value is None: + return value + if value < 0.0 or value > 2.0: + raise ValueError("temperature must be between 0 and 2") + return value + + +def _llm_timeout() -> httpx.Timeout: + return httpx.Timeout(LLM_TIMEOUT, connect=min(LLM_TIMEOUT, 10)) + + +def _assert_llm_enabled() -> None: + if not LLM_ENABLED: + raise LLMGatewayError(503, "llm_disabled", "LLM service is disabled") + + +def _extract_upstream_error_message(response: httpx.Response) -> str: + try: + payload = response.json() + except Exception: + payload = None + + if isinstance(payload, dict): + error = payload.get("error") + if isinstance(error, dict) and error.get("message"): + return str(error["message"]) + if payload.get("message"): + return str(payload["message"]) + if payload.get("detail"): + return str(payload["detail"]) + + text = response.text.strip() + return text[:500] if text else f"Upstream returned HTTP {response.status_code}" + + +def _map_upstream_llm_status(status_code: int) -> int: + if status_code in (400, 413, 422): + return status_code + if 400 <= status_code < 500: + return 422 + return 503 + + +def _normalize_chat_payload(payload: ChatCompletionRequest) -> Dict[str, Any]: + normalized = payload.model_dump(exclude_none=True) + normalized["model"] = normalized.get("model") or LLM_DEFAULT_MODEL + normalized["max_tokens"] = min( + int(normalized.get("max_tokens") or LLM_MAX_TOKENS_DEFAULT), + LLM_MAX_TOKENS_HARD_LIMIT, + ) + + if "temperature" in normalized: + normalized["temperature"] = max(0.0, min(2.0, float(normalized["temperature"]))) + + if normalized.get("stream"): + raise LLMGatewayError( + 422, + "streaming_not_supported", + "Streaming responses are not enabled for this gateway", + ) + + return normalized + + +async def _parse_llm_request(request: Request) -> ChatCompletionRequest: + content_length = request.headers.get("content-length") + if content_length: + try: + if int(content_length) > LLM_MAX_REQUEST_BYTES: + raise LLMGatewayError( + 413, + "payload_too_large", + f"Request exceeds {LLM_MAX_REQUEST_BYTES} bytes", + ) + except ValueError: + raise LLMGatewayError(400, "invalid_request", "Invalid Content-Length header") + + body = await request.body() + if not body: + raise LLMGatewayError(400, "invalid_request", "Request body is required") + if len(body) > LLM_MAX_REQUEST_BYTES: + raise LLMGatewayError( + 413, + "payload_too_large", + f"Request exceeds {LLM_MAX_REQUEST_BYTES} bytes", + ) + + try: + payload = json.loads(body) + except json.JSONDecodeError: + raise LLMGatewayError(400, "invalid_json", "Request body must be valid JSON") + + if not isinstance(payload, dict): + raise LLMGatewayError(400, "invalid_request", "JSON body must be an object") + + try: + return ChatCompletionRequest.model_validate(payload) + except ValidationError as exc: + raise LLMGatewayError(422, "validation_error", "Invalid chat request", exc.errors()) + + +async def _llm_request( + method: str, + path: str, + *, + json_payload: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + _assert_llm_enabled() + + url = f"{LLM_URL}{path}" + try: + response = await get_http_client().request( + method, + url, + json=json_payload, + timeout=_llm_timeout(), + ) + except httpx.TimeoutException: + raise LLMGatewayError(504, "llm_timeout", "LLM request timed out") + except httpx.RequestError as exc: + raise LLMGatewayError(503, "llm_unavailable", f"LLM service is unavailable: {exc}") + + if response.status_code >= 500: + raise LLMGatewayError(503, "llm_unavailable", _extract_upstream_error_message(response)) + if response.status_code >= 400: + raise LLMGatewayError( + _map_upstream_llm_status(response.status_code), + "llm_rejected_request", + _extract_upstream_error_message(response), + ) + + try: + return response.json() + except Exception: + raise LLMGatewayError(503, "llm_invalid_response", "LLM service returned invalid JSON") + + +def _normalize_ai_chat_response(response: Dict[str, Any]) -> Dict[str, Any]: + choices = response.get("choices") + if not isinstance(choices, list) or not choices: + raise LLMGatewayError(503, "llm_invalid_response", "LLM response did not contain choices") + + first_choice = choices[0] if isinstance(choices[0], dict) else {} + message = first_choice.get("message") if isinstance(first_choice.get("message"), dict) else {} + content = message.get("content") + if not isinstance(content, str): + raise LLMGatewayError(503, "llm_invalid_response", "LLM response did not contain message content") + + usage = response.get("usage") if isinstance(response.get("usage"), dict) else {} + return { + "model": response.get("model") or LLM_DEFAULT_MODEL, + "content": content, + "finish_reason": first_choice.get("finish_reason") or "stop", + "usage": { + "prompt_tokens": int(usage.get("prompt_tokens") or 0), + "completion_tokens": int(usage.get("completion_tokens") or 0), + "total_tokens": int(usage.get("total_tokens") or 0), + }, + } + + +async def _get_llm_models_payload() -> Dict[str, Any]: + models = await _llm_request("GET", "/v1/models") + if isinstance(models.get("data"), list) and models["data"]: + return models + return { + "object": "list", + "data": [ + { + "id": LLM_DEFAULT_MODEL, + "object": "model", + "owned_by": "self-hosted", + } + ], + } + + async def _get_health(base: str) -> Dict[str, Any]: try: r = await get_http_client().get(f"{base}/health", timeout=5) @@ -184,8 +452,12 @@ async def health(): _get_health(YOLO_URL), _get_health(QDRANT_SVC_URL), ] + llm_index: Optional[int] = None if MATURITY_ENABLED: health_checks.append(_get_health(MATURITY_URL)) + if LLM_ENABLED: + llm_index = len(health_checks) + health_checks.append(_get_health(LLM_URL)) results = await asyncio.gather(*health_checks) services: Dict[str, Any] = { @@ -196,9 +468,71 @@ async def health(): } if MATURITY_ENABLED: services["maturity"] = results[4] + if LLM_ENABLED and llm_index is not None: + services["llm"] = { + "enabled": True, + "default_model": LLM_DEFAULT_MODEL, + "upstream": results[llm_index], + } + else: + services["llm"] = { + "enabled": False, + "default_model": LLM_DEFAULT_MODEL, + "upstream": {"status": "disabled"}, + } return {"status": "ok", "services": services} +@app.post("/v1/chat/completions") +async def llm_chat_completions(request: Request): + payload = _normalize_chat_payload(await _parse_llm_request(request)) + return await _llm_request("POST", "/v1/chat/completions", json_payload=payload) + + +@app.get("/v1/models") +async def llm_models(): + return await _get_llm_models_payload() + + +@app.post("/ai/chat") +async def ai_chat(request: Request): + payload = _normalize_chat_payload(await _parse_llm_request(request)) + response = await _llm_request("POST", "/v1/chat/completions", json_payload=payload) + return _normalize_ai_chat_response(response) + + +@app.get("/ai/models") +async def ai_models(): + models = await _get_llm_models_payload() + return { + "enabled": LLM_ENABLED, + "default_model": LLM_DEFAULT_MODEL, + "models": models.get("data", []), + } + + +@app.get("/ai/health") +async def ai_health(): + if not LLM_ENABLED: + return { + "status": "ok", + "enabled": False, + "reachable": False, + "default_model": LLM_DEFAULT_MODEL, + "upstream": {"status": "disabled"}, + } + + upstream = await _get_health(LLM_URL) + reachable = upstream.get("status") == "ok" + return { + "status": "ok" if reachable else "degraded", + "enabled": True, + "reachable": reachable, + "default_model": LLM_DEFAULT_MODEL, + "upstream": upstream, + } + + # ---- Individual analyze endpoints (URL) ---- @app.post("/analyze/clip") diff --git a/llm/Dockerfile b/llm/Dockerfile new file mode 100644 index 0000000..a2dbada --- /dev/null +++ b/llm/Dockerfile @@ -0,0 +1,53 @@ +FROM debian:bookworm-slim AS builder + +ARG LLAMA_CPP_REPO=https://github.com/ggml-org/llama.cpp.git +ARG LLAMA_CPP_REF= + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + ca-certificates \ + cmake \ + git \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /src +RUN git clone --depth 1 ${LLAMA_CPP_REPO} llama.cpp \ + && if [ -n "${LLAMA_CPP_REF}" ]; then cd llama.cpp && git fetch --depth 1 origin "${LLAMA_CPP_REF}" && git checkout "${LLAMA_CPP_REF}"; fi + +WORKDIR /src/llama.cpp +RUN cmake -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON \ + && cmake --build build --config Release --target llama-server -j"$(nproc)" + +FROM python:3.11-slim + +RUN apt-get update && apt-get install -y --no-install-recommends \ + bash \ + ca-certificates \ + curl \ + libgomp1 \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY llm/requirements.txt /app/requirements.txt +RUN pip install --no-cache-dir -r /app/requirements.txt + +COPY --from=builder /src/llama.cpp/build/bin/llama-server /usr/local/bin/llama-server +COPY llm/main.py /app/main.py +COPY llm/entrypoint.sh /entrypoint.sh + +RUN chmod +x /entrypoint.sh /usr/local/bin/llama-server + +ENV MODEL_PATH=/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf \ + LLM_MODEL_NAME=qwen3-1.7b-instruct-q4_k_m \ + LLM_CONTEXT_SIZE=4096 \ + LLM_THREADS=4 \ + LLM_GPU_LAYERS=0 \ + LLM_PORT=8080 \ + LLAMA_SERVER_PORT=8081 \ + LLM_STARTUP_TIMEOUT=120 \ + LLM_EXTRA_ARGS= + +EXPOSE 8080 + +ENTRYPOINT ["/entrypoint.sh"] \ No newline at end of file diff --git a/llm/entrypoint.sh b/llm/entrypoint.sh new file mode 100644 index 0000000..6923e67 --- /dev/null +++ b/llm/entrypoint.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -eu + +MODEL_PATH="${MODEL_PATH:-/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf}" +LLM_MODEL_NAME="${LLM_MODEL_NAME:-qwen3-1.7b-instruct-q4_k_m}" +LLM_CONTEXT_SIZE="${LLM_CONTEXT_SIZE:-4096}" +LLM_THREADS="${LLM_THREADS:-4}" +LLM_GPU_LAYERS="${LLM_GPU_LAYERS:-0}" +LLM_PORT="${LLM_PORT:-8080}" +LLAMA_SERVER_PORT="${LLAMA_SERVER_PORT:-8081}" + +if [ ! -f "$MODEL_PATH" ]; then + echo "llm startup failed: model file not found at $MODEL_PATH" >&2 + echo "Mount a GGUF model into ./models/qwen3 and set MODEL_PATH if the filename differs." >&2 + exit 1 +fi + +if [ ! -r "$MODEL_PATH" ]; then + echo "llm startup failed: model file is not readable at $MODEL_PATH" >&2 + exit 1 +fi + +echo "Starting llm shim model=$LLM_MODEL_NAME model_path=$MODEL_PATH public_port=$LLM_PORT upstream_port=$LLAMA_SERVER_PORT ctx=$LLM_CONTEXT_SIZE threads=$LLM_THREADS gpu_layers=$LLM_GPU_LAYERS" + +exec python -m uvicorn main:app --host 0.0.0.0 --port "$LLM_PORT" \ No newline at end of file diff --git a/llm/main.py b/llm/main.py new file mode 100644 index 0000000..e63498a --- /dev/null +++ b/llm/main.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import asyncio +import logging +import os +import shlex +import subprocess +import time +from contextlib import asynccontextmanager +from pathlib import Path +from typing import Any, Dict, Optional + +import httpx +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse + +logger = logging.getLogger("llm") +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s") + +MODEL_PATH = os.getenv("MODEL_PATH", "/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf") +LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "qwen3-1.7b-instruct-q4_k_m") +LLM_CONTEXT_SIZE = int(os.getenv("LLM_CONTEXT_SIZE", "4096")) +LLM_THREADS = int(os.getenv("LLM_THREADS", "4")) +LLM_GPU_LAYERS = int(os.getenv("LLM_GPU_LAYERS", "0")) +LLAMA_SERVER_PORT = int(os.getenv("LLAMA_SERVER_PORT", "8081")) +LLM_STARTUP_TIMEOUT = float(os.getenv("LLM_STARTUP_TIMEOUT", "120")) +LLM_EXTRA_ARGS = os.getenv("LLM_EXTRA_ARGS", "") + +_llama_process: subprocess.Popen[bytes] | None = None +_http_client: httpx.AsyncClient | None = None + + +def _upstream_base_url() -> str: + return f"http://127.0.0.1:{LLAMA_SERVER_PORT}" + + +def _ensure_http_client() -> httpx.AsyncClient: + if _http_client is None: + raise RuntimeError("HTTP client not initialised") + return _http_client + + +def _validate_model_path() -> None: + model_file = Path(MODEL_PATH) + if not model_file.is_file(): + raise RuntimeError(f"model file not found at {MODEL_PATH}") + if not os.access(model_file, os.R_OK): + raise RuntimeError(f"model file is not readable at {MODEL_PATH}") + + +def _build_llama_command() -> list[str]: + command = [ + "/usr/local/bin/llama-server", + "--host", + "127.0.0.1", + "--port", + str(LLAMA_SERVER_PORT), + "--model", + MODEL_PATH, + "--alias", + LLM_MODEL_NAME, + "--ctx-size", + str(LLM_CONTEXT_SIZE), + "--threads", + str(LLM_THREADS), + "--n-gpu-layers", + str(LLM_GPU_LAYERS), + ] + if LLM_EXTRA_ARGS.strip(): + command.extend(shlex.split(LLM_EXTRA_ARGS)) + return command + + +def _llama_running() -> bool: + return _llama_process is not None and _llama_process.poll() is None + + +async def _wait_for_llama_ready() -> None: + deadline = time.monotonic() + LLM_STARTUP_TIMEOUT + last_error: Optional[Exception] = None + + while time.monotonic() < deadline: + if _llama_process is not None and _llama_process.poll() is not None: + raise RuntimeError(f"llama-server exited with code {_llama_process.poll()}") + + try: + response = await _ensure_http_client().get(f"{_upstream_base_url()}/v1/models", timeout=5) + if response.status_code == 200: + logger.info("llm service: llama-server ready") + return + except Exception as exc: + last_error = exc + + await asyncio.sleep(1) + + raise RuntimeError(f"llama-server did not become ready within {LLM_STARTUP_TIMEOUT}s: {last_error}") + + +async def _stop_llama_process() -> None: + global _llama_process + + if _llama_process is None: + return + + if _llama_process.poll() is None: + _llama_process.terminate() + try: + await asyncio.to_thread(_llama_process.wait, timeout=10) + except subprocess.TimeoutExpired: + _llama_process.kill() + await asyncio.to_thread(_llama_process.wait, timeout=5) + + _llama_process = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global _http_client, _llama_process + + _validate_model_path() + _http_client = httpx.AsyncClient(timeout=httpx.Timeout(120, connect=5)) + + command = _build_llama_command() + logger.info("llm service: starting llama-server model=%s ctx=%s threads=%s gpu_layers=%s upstream_port=%s", LLM_MODEL_NAME, LLM_CONTEXT_SIZE, LLM_THREADS, LLM_GPU_LAYERS, LLAMA_SERVER_PORT) + _llama_process = subprocess.Popen(command) + + try: + await _wait_for_llama_ready() + yield + finally: + await _stop_llama_process() + if _http_client is not None: + await _http_client.aclose() + _http_client = None + + +app = FastAPI(title="Skinbase LLM Service", version="1.0.0", lifespan=lifespan) + + +def _health_payload(status: str) -> Dict[str, Any]: + return { + "status": status, + "model": Path(MODEL_PATH).name, + "model_alias": LLM_MODEL_NAME, + "context_size": LLM_CONTEXT_SIZE, + "threads": LLM_THREADS, + "gpu_layers": LLM_GPU_LAYERS, + } + + +async def _proxy_request(method: str, path: str, *, body: bytes | None = None) -> Dict[str, Any]: + if not _llama_running(): + raise HTTPException(status_code=503, detail="llama-server is not running") + + headers = {"content-type": "application/json"} if body is not None else None + try: + response = await _ensure_http_client().request( + method, + f"{_upstream_base_url()}{path}", + content=body, + headers=headers, + timeout=httpx.Timeout(120, connect=5), + ) + except httpx.TimeoutException as exc: + raise HTTPException(status_code=504, detail=f"llama-server timed out: {exc}") + except httpx.RequestError as exc: + raise HTTPException(status_code=503, detail=f"llama-server unavailable: {exc}") + + if response.status_code >= 400: + detail: Any + try: + detail = response.json() + except Exception: + detail = response.text[:1000] + raise HTTPException(status_code=response.status_code, detail=detail) + + try: + return response.json() + except Exception as exc: + raise HTTPException(status_code=502, detail=f"llama-server returned invalid JSON: {exc}") + + +@app.exception_handler(HTTPException) +async def handle_http_exception(_: Request, exc: HTTPException): + return JSONResponse(status_code=exc.status_code, content={"error": {"code": "llm_service_error", "message": str(exc.detail)}}) + + +@app.get("/health") +async def health(): + if not _llama_running(): + return JSONResponse(status_code=503, content=_health_payload("unavailable")) + + try: + response = await _ensure_http_client().get(f"{_upstream_base_url()}/v1/models", timeout=5) + if response.status_code != 200: + return JSONResponse(status_code=503, content=_health_payload("degraded")) + except Exception: + return JSONResponse(status_code=503, content=_health_payload("degraded")) + + return _health_payload("ok") + + +@app.get("/v1/models") +async def list_models(): + return await _proxy_request("GET", "/v1/models") + + +@app.post("/v1/chat/completions") +async def chat_completions(request: Request): + body = await request.body() + return await _proxy_request("POST", "/v1/chat/completions", body=body) \ No newline at end of file diff --git a/llm/requirements.txt b/llm/requirements.txt new file mode 100644 index 0000000..94dd638 --- /dev/null +++ b/llm/requirements.txt @@ -0,0 +1,3 @@ +fastapi==0.115.5 +uvicorn[standard]==0.30.6 +httpx==0.27.2 \ No newline at end of file diff --git a/models/qwen3/README.md b/models/qwen3/README.md new file mode 100644 index 0000000..ad05586 --- /dev/null +++ b/models/qwen3/README.md @@ -0,0 +1,9 @@ +Place the Qwen3 GGUF model file for the local llm profile in this directory. + +Expected default filename: + +- `Qwen3-1.7B-Instruct-Q4_K_M.gguf` + +You can use a different filename, but then set `MODEL_PATH` in `.env` to match the mounted path inside the container. + +The model is intentionally not auto-downloaded at startup. Operators should provision it explicitly so container startup is predictable. \ No newline at end of file diff --git a/qdrant/backfill_payloads.py b/qdrant/backfill_payloads.py new file mode 100644 index 0000000..641da92 --- /dev/null +++ b/qdrant/backfill_payloads.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +""" +backfill_payloads.py — Repair missing payload fields for existing Qdrant points. + +WHY THIS EXISTS +--------------- +If artworks were initially upserted without the full payload (is_public, is_nsfw, +category_id, content_type_id, is_deleted, status), those fields will have near-0% +coverage in the payload index. This prevents filtered searches (e.g., is_public=true) +from returning correct results. + +This script scrolls through all points in the collection, detects which ones are +missing the required fields, and lets you supply a lookup function that fetches the +correct values from your source-of-truth (database, API, CSV, etc.). + +HOW TO ADAPT +------------ +1. Fill in `fetch_payloads_for_ids()` to return a dict mapping qdrant-point-id -> + payload patch for each missing ID. The simplest approach is a SQL query to your + Skinbase database using the `_original_id` stored in the Qdrant payload. + +2. Run the script directly (no app container needed, just qdrant-client installed): + + # Inside Docker network: + docker exec -it vision-qdrant-svc-1 python /app/backfill_payloads.py + + # Or from host with qdrant-client installed: + pip install qdrant-client + QDRANT_HOST=localhost QDRANT_PORT=6333 python qdrant/backfill_payloads.py + +3. The script is resumable: it prints the last-processed offset ID so you can + restart from where you left off by setting RESUME_OFFSET env var. + +REQUIRED ENV VARS (all optional, sensible defaults for Docker Compose): + QDRANT_HOST default: qdrant + QDRANT_PORT default: 6333 + COLLECTION_NAME default: images + BATCH_SIZE default: 256 + DRY_RUN default: 0 (set to 1 to only report, no writes) + RESUME_OFFSET default: None (UUID or int of last seen point to skip to) + +FIELDS CHECKED +-------------- + user_id, is_public, is_nsfw, category_id, content_type_id, is_deleted, status +""" + +from __future__ import annotations + +import os +import sys +import time +import logging +from typing import Any, Dict, List, Optional + +from qdrant_client import QdrantClient +from qdrant_client.models import PointStruct + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +log = logging.getLogger("backfill") + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant") +QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333")) +COLLECTION_NAME = os.getenv("COLLECTION_NAME", "images") +BATCH_SIZE = int(os.getenv("BATCH_SIZE", "256")) +DRY_RUN = os.getenv("DRY_RUN", "0") == "1" +RESUME_OFFSET: Optional[str] = os.getenv("RESUME_OFFSET") # point id to continue from + +# Fields that MUST be present in every point payload for filtered search to work. +REQUIRED_FIELDS = [ + "user_id", + "is_public", + "is_nsfw", + "category_id", + "content_type_id", + "is_deleted", + "status", +] + + +# --------------------------------------------------------------------------- +# TODO: implement this function to fetch correct payload values from your DB. +# --------------------------------------------------------------------------- + +def fetch_payloads_for_ids( + missing_ids: List[Any], + original_ids: Dict[Any, str], +) -> Dict[Any, Dict[str, Any]]: + """Return a mapping of qdrant_point_id -> payload_patch for the given IDs. + + Parameters + ---------- + missing_ids: + List of Qdrant point IDs (UUID strings or ints) that need patching. + original_ids: + Dict mapping qdrant_point_id -> original application ID (stored in + `_original_id` payload field, or the point id itself if they match). + + Returns + ------- + Dict mapping each point id to a dict of fields to set. + Only include the fields you want to SET — existing fields are not cleared. + + Example implementation (pseudo-code for your database): + + import psycopg2 + conn = psycopg2.connect(os.environ["DATABASE_URL"]) + cur = conn.cursor() + orig_id_list = list(original_ids.values()) + cur.execute( + "SELECT id, user_id, is_public, is_nsfw, category_id, " + " content_type_id, is_deleted, status " + "FROM artworks WHERE id = ANY(%s)", + (orig_id_list,) + ) + rows = cur.fetchall() + by_orig = {str(r[0]): r for r in rows} + result = {} + for qdrant_id, orig_id in original_ids.items(): + row = by_orig.get(str(orig_id)) + if row: + result[qdrant_id] = { + "user_id": str(row[1]), + "is_public": bool(row[2]), + "is_nsfw": bool(row[3]), + "category_id": int(row[4]) if row[4] is not None else None, + "content_type_id": int(row[5]) if row[5] is not None else None, + "is_deleted": bool(row[6]), + "status": str(row[7]), + } + return result + """ + # ---- STUB: replace with your real implementation ---- + log.warning( + "fetch_payloads_for_ids() is a stub — no data will be patched.\n" + "Edit qdrant/backfill_payloads.py and implement this function." + ) + return {} + + +# --------------------------------------------------------------------------- +# Core backfill logic +# --------------------------------------------------------------------------- + +def run_backfill(): + log.info( + "backfill start collection=%s host=%s:%s dry_run=%s batch=%d", + COLLECTION_NAME, QDRANT_HOST, QDRANT_PORT, DRY_RUN, BATCH_SIZE, + ) + + qclient = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT) + + # Verify collection exists + collections = [c.name for c in qclient.get_collections().collections] + if COLLECTION_NAME not in collections: + log.error("Collection '%s' not found. Existing: %s", COLLECTION_NAME, collections) + sys.exit(1) + + info = qclient.get_collection(COLLECTION_NAME) + total_points = info.points_count or 0 + log.info("collection points_count=%d indexed_vectors=%d", total_points, info.indexed_vectors_count or 0) + + offset = RESUME_OFFSET + scanned = 0 + missing_count = 0 + patched = 0 + errors = 0 + t_start = time.perf_counter() + + while True: + points, next_offset = qclient.scroll( + collection_name=COLLECTION_NAME, + offset=offset, + limit=BATCH_SIZE, + with_payload=True, + with_vectors=False, + ) + + if not points: + break + + scanned += len(points) + + # Find points missing any required field + needs_patch: List[Any] = [] + original_ids: Dict[Any, str] = {} + for pt in points: + payload = pt.payload or {} + missing = [f for f in REQUIRED_FIELDS if f not in payload or payload[f] is None] + if missing: + needs_patch.append(pt.id) + # Use _original_id if present (IDs that couldn't be stored as Qdrant IDs) + original_ids[pt.id] = str(payload.get("_original_id", pt.id)) + missing_count += 1 + + if needs_patch: + patches = fetch_payloads_for_ids(needs_patch, original_ids) + for pid, patch in patches.items(): + if not patch: + continue + if DRY_RUN: + log.info("[DRY RUN] would patch id=%s fields=%s", pid, list(patch.keys())) + else: + try: + qclient.set_payload( + collection_name=COLLECTION_NAME, + payload=patch, + points=[pid], + ) + patched += 1 + except Exception as exc: + log.error("failed to patch id=%s: %s", pid, exc) + errors += 1 + + elapsed = time.perf_counter() - t_start + rate = scanned / elapsed if elapsed > 0 else 0 + log.info( + "progress scanned=%d/%d missing=%d patched=%d errors=%d rate=%.0f/s offset=%s", + scanned, total_points, missing_count, patched, errors, rate, next_offset, + ) + + if next_offset is None: + break + offset = next_offset + + elapsed = time.perf_counter() - t_start + log.info( + "backfill complete scanned=%d missing=%d patched=%d errors=%d elapsed=%.1fs", + scanned, missing_count, patched, errors, elapsed, + ) + + if missing_count > 0 and patched == 0 and not DRY_RUN: + log.warning( + "%d points are missing payload fields but 0 were patched. " + "Implement fetch_payloads_for_ids() in this script.", + missing_count, + ) + + +if __name__ == "__main__": + run_backfill() diff --git a/qdrant/main.py b/qdrant/main.py index 43e2f19..3928eba 100644 --- a/qdrant/main.py +++ b/qdrant/main.py @@ -1,6 +1,9 @@ from __future__ import annotations +import asyncio +import logging import os +import time import uuid from typing import Any, Dict, List, Optional @@ -39,6 +42,9 @@ SEARCH_HNSW_EF = int(os.getenv("SEARCH_HNSW_EF", "128")) app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0") client: QdrantClient = None # type: ignore[assignment] +logger = logging.getLogger("qdrant_svc") +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s") + # --------------------------------------------------------------------------- # Startup / shutdown @@ -47,8 +53,24 @@ client: QdrantClient = None # type: ignore[assignment] @app.on_event("startup") def startup(): global client + t0 = time.perf_counter() + logger.info("qdrant_svc startup: connecting to %s:%s", QDRANT_HOST, QDRANT_PORT) client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT) _ensure_collection() + # Warm the gRPC/HTTP connection and load collection metadata into memory + # so the first real request does not pay the one-time connect cost. + try: + info = client.get_collection(COLLECTION_NAME) + logger.info( + "qdrant_svc startup: warm ping OK collection=%s points=%s indexed=%s elapsed_ms=%.1f", + COLLECTION_NAME, + info.points_count, + info.indexed_vectors_count, + (time.perf_counter() - t0) * 1000, + ) + except Exception as exc: + logger.warning("qdrant_svc startup: warm ping failed (non-fatal): %s", exc) + logger.info("qdrant_svc startup complete elapsed_ms=%.1f", (time.perf_counter() - t0) * 1000) def _ensure_collection(): @@ -68,6 +90,44 @@ def _ensure_collection(): default_segment_number=4, # parallelism-friendly segment count ), ) + _ensure_payload_indexes() + + +# Payload fields needed for filtered search. type values match PayloadSchemaType. +_REQUIRED_PAYLOAD_INDEXES: List[Dict[str, str]] = [ + {"field": "user_id", "type": "keyword"}, + {"field": "is_public", "type": "bool"}, + {"field": "is_nsfw", "type": "bool"}, + {"field": "is_deleted", "type": "bool"}, + {"field": "status", "type": "keyword"}, + {"field": "category_id", "type": "integer"}, + {"field": "content_type_id", "type": "integer"}, +] + + +def _ensure_payload_indexes(): + """Create any missing payload indexes for the default collection.""" + try: + info = client.get_collection(COLLECTION_NAME) + except Exception: + return # collection doesn't exist yet, will be created next + existing = set(info.payload_schema.keys()) if info.payload_schema else set() + for spec in _REQUIRED_PAYLOAD_INDEXES: + field = spec["field"] + if field in existing: + continue + schema = _SCHEMA_TYPE_MAP.get(spec["type"]) + if schema is None: + continue + try: + client.create_payload_index( + collection_name=COLLECTION_NAME, + field_name=field, + field_schema=schema, + ) + logger.info("_ensure_payload_indexes: created index field=%s type=%s", field, spec["type"]) + except Exception as exc: + logger.warning("_ensure_payload_indexes: could not index field=%s: %s", field, exc) # --------------------------------------------------------------------------- @@ -213,23 +273,31 @@ def health(): @app.get("/inspect") -def inspect(): +async def inspect(): """Return a full diagnostic summary for every collection. Covers: vector counts, segment counts, HNSW config, optimizer config, quantization, payload indexes and their coverage. Designed for production health checks and the Qdrant optimization workflow. """ + t0 = time.perf_counter() + logger.info("inspect: start") + try: - all_collections = client.get_collections().collections + all_collections = await asyncio.get_event_loop().run_in_executor( + None, lambda: client.get_collections().collections + ) except Exception as exc: return {"status": "error", "detail": str(exc)} result = {} for col_desc in all_collections: name = col_desc.name + t_col = time.perf_counter() try: - info = client.get_collection(name) + info = await asyncio.get_event_loop().run_in_executor( + None, lambda n=name: client.get_collection(n) + ) cfg = info.config hnsw = cfg.hnsw_config opt = cfg.optimizer_config @@ -281,9 +349,14 @@ def inspect(): "payload_index_count": len(info.payload_schema or {}), "search_hnsw_ef": SEARCH_HNSW_EF, } + logger.info( + "inspect: collection=%s points=%s elapsed_ms=%.1f", + name, points_count, (time.perf_counter() - t_col) * 1000, + ) except Exception as exc: result[name] = {"error": str(exc)} + logger.info("inspect: done collections=%d total_elapsed_ms=%.1f", len(result), (time.perf_counter() - t0) * 1000) return {"collections": result, "total": len(result)} @@ -757,3 +830,54 @@ def configure_collection(name: str, req: CollectionConfigRequest): } except Exception as exc: raise HTTPException(500, str(exc)) + + +# --------------------------------------------------------------------------- +# Payload update (used by backfill / repair tooling) +# --------------------------------------------------------------------------- + +class BatchUpdatePayloadRequest(BaseModel): + """Update payload fields for a batch of points identified by their Qdrant IDs. + + ``updates`` is a list of ``{"id": "", "payload": {...}}`` items. + Only the supplied payload keys are merged into existing payloads (set_payload + semantics — existing keys not mentioned are left untouched). + """ + updates: List[Dict[str, Any]] + collection: Optional[str] = None + + +@app.post("/points/batch-update-payload") +def batch_update_payload(req: BatchUpdatePayloadRequest): + """Merge payload fields for a list of points without touching vectors. + + Useful for backfilling metadata (is_public, category_id, etc.) for points + that were upserted without full payload coverage. + """ + if not req.updates: + return {"updated": 0, "collection": _col(req.collection)} + + col = _col(req.collection) + updated = 0 + errors: List[str] = [] + + for item in req.updates: + pid_raw = item.get("id") + payload = item.get("payload", {}) + if pid_raw is None or not payload: + continue + pid = _point_id(str(pid_raw)) + try: + client.set_payload( + collection_name=col, + payload=payload, + points=[pid], + ) + updated += 1 + except Exception as exc: + errors.append(f"{pid_raw}: {exc}") + + result: Dict[str, Any] = {"updated": updated, "collection": col} + if errors: + result["errors"] = errors + return result diff --git a/tests/test_gateway_llm.py b/tests/test_gateway_llm.py new file mode 100644 index 0000000..5452edb --- /dev/null +++ b/tests/test_gateway_llm.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import importlib +import os +import unittest +from typing import Any, Dict, Optional +from unittest.mock import patch + +import httpx + + +BASE_ENV = { + "API_KEY": "test-key", + "CLIP_URL": "http://clip:8000", + "BLIP_URL": "http://blip:8000", + "YOLO_URL": "http://yolo:8000", + "QDRANT_SVC_URL": "http://qdrant-svc:8000", + "CARD_RENDERER_URL": "http://card-renderer:8000", + "MATURITY_URL": "http://maturity:8000", + "LLM_URL": "http://llm:8080", + "LLM_TIMEOUT": "5", + "LLM_DEFAULT_MODEL": "qwen3-1.7b-instruct-q4_k_m", + "LLM_MAX_TOKENS_DEFAULT": "256", + "LLM_MAX_TOKENS_HARD_LIMIT": "1024", + "LLM_MAX_REQUEST_BYTES": "65536", +} + + +def load_gateway_module(*, llm_enabled: bool, extra_env: Optional[Dict[str, str]] = None): + env = BASE_ENV | {"LLM_ENABLED": "true" if llm_enabled else "false"} + if extra_env: + env |= extra_env + with patch.dict(os.environ, env, clear=False): + import gateway.main as gateway_main + + return importlib.reload(gateway_main) + + +class StubUpstreamClient: + def __init__( + self, + *, + request_responses: Optional[Dict[tuple[str, str], httpx.Response]] = None, + get_responses: Optional[Dict[str, httpx.Response]] = None, + request_exception: Optional[Exception] = None, + get_exception: Optional[Exception] = None, + ): + self.request_responses = request_responses or {} + self.get_responses = get_responses or {} + self.request_exception = request_exception + self.get_exception = get_exception + + async def request(self, method: str, url: str, **_: Any) -> httpx.Response: + if self.request_exception is not None: + raise self.request_exception + response = self.request_responses.get((method.upper(), url)) + if response is None: + return httpx.Response(404, json={"error": {"message": f"No stub for {method} {url}"}}) + return response + + async def get(self, url: str, **_: Any) -> httpx.Response: + if self.get_exception is not None: + raise self.get_exception + response = self.get_responses.get(url) + if response is None: + return httpx.Response(404, json={"detail": f"No stub for GET {url}"}) + return response + + +class GatewayLLMTests(unittest.IsolatedAsyncioTestCase): + async def _request( + self, + module: Any, + method: str, + path: str, + *, + headers: Optional[Dict[str, str]] = None, + json_payload: Optional[Dict[str, Any]] = None, + content: Optional[bytes] = None, + ) -> httpx.Response: + transport = httpx.ASGITransport(app=module.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + return await client.request(method, path, headers=headers, json=json_payload, content=content) + + async def test_llm_endpoint_requires_api_key(self): + module = load_gateway_module(llm_enabled=True) + + response = await self._request( + module, + "POST", + "/ai/chat", + json_payload={"messages": [{"role": "user", "content": "hello"}]}, + ) + + self.assertEqual(response.status_code, 401) + self.assertEqual(response.json()["error"]["code"], "unauthorized") + + async def test_llm_disabled_returns_503(self): + module = load_gateway_module(llm_enabled=False) + + response = await self._request( + module, + "POST", + "/ai/chat", + headers={"X-API-Key": "test-key"}, + json_payload={"messages": [{"role": "user", "content": "hello"}]}, + ) + + self.assertEqual(response.status_code, 503) + self.assertEqual(response.json()["error"]["code"], "llm_disabled") + + async def test_unreachable_llm_returns_normalized_503(self): + module = load_gateway_module(llm_enabled=True) + stub_client = StubUpstreamClient( + request_exception=httpx.ConnectError("boom", request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions")), + ) + + with patch.object(module, "get_http_client", return_value=stub_client): + response = await self._request( + module, + "POST", + "/ai/chat", + headers={"X-API-Key": "test-key"}, + json_payload={"messages": [{"role": "user", "content": "hello"}]}, + ) + + self.assertEqual(response.status_code, 503) + self.assertEqual(response.json()["error"]["code"], "llm_unavailable") + + async def test_validation_error_is_normalized(self): + module = load_gateway_module(llm_enabled=True) + + response = await self._request( + module, + "POST", + "/ai/chat", + headers={"X-API-Key": "test-key"}, + json_payload={"messages": []}, + ) + + self.assertEqual(response.status_code, 422) + self.assertEqual(response.json()["error"]["code"], "validation_error") + + async def test_invalid_json_returns_400(self): + module = load_gateway_module(llm_enabled=True) + + response = await self._request( + module, + "POST", + "/v1/chat/completions", + headers={"X-API-Key": "test-key", "Content-Type": "application/json"}, + content=b'{"messages": [', + ) + + self.assertEqual(response.status_code, 400) + self.assertEqual(response.json()["error"]["code"], "invalid_json") + + async def test_oversized_payload_returns_413(self): + module = load_gateway_module(llm_enabled=True, extra_env={"LLM_MAX_REQUEST_BYTES": "64"}) + + response = await self._request( + module, + "POST", + "/v1/chat/completions", + headers={"X-API-Key": "test-key"}, + json_payload={"messages": [{"role": "user", "content": "x" * 5000}]}, + ) + + self.assertEqual(response.status_code, 413) + self.assertEqual(response.json()["error"]["code"], "payload_too_large") + + async def test_ai_chat_normalizes_successful_response(self): + module = load_gateway_module(llm_enabled=True) + upstream_response = httpx.Response( + 200, + json={ + "id": "chatcmpl-1", + "object": "chat.completion", + "model": "qwen3-1.7b-instruct-q4_k_m", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": {"role": "assistant", "content": "Generated text here."}, + } + ], + "usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20}, + }, + request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions"), + ) + stub_client = StubUpstreamClient( + request_responses={("POST", f"{module.LLM_URL}/v1/chat/completions"): upstream_response}, + ) + + with patch.object(module, "get_http_client", return_value=stub_client): + response = await self._request( + module, + "POST", + "/ai/chat", + headers={"X-API-Key": "test-key"}, + json_payload={"messages": [{"role": "user", "content": "hello"}]}, + ) + + self.assertEqual(response.status_code, 200) + self.assertEqual( + response.json(), + { + "model": "qwen3-1.7b-instruct-q4_k_m", + "content": "Generated text here.", + "finish_reason": "stop", + "usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20}, + }, + ) + + async def test_ai_health_reports_reachable_llm(self): + module = load_gateway_module(llm_enabled=True) + stub_client = StubUpstreamClient( + get_responses={ + f"{module.LLM_URL}/health": httpx.Response( + 200, + json={"status": "ok", "model": "Qwen3-1.7B-Instruct-Q4_K_M.gguf", "context_size": 4096, "threads": 4}, + request=httpx.Request("GET", f"{module.LLM_URL}/health"), + ) + }, + ) + + with patch.object(module, "get_http_client", return_value=stub_client): + response = await self._request( + module, + "GET", + "/ai/health", + headers={"X-API-Key": "test-key"}, + ) + + self.assertEqual(response.status_code, 200) + self.assertTrue(response.json()["reachable"]) + self.assertEqual(response.json()["default_model"], "qwen3-1.7b-instruct-q4_k_m") + + async def test_timeout_returns_504(self): + module = load_gateway_module(llm_enabled=True) + stub_client = StubUpstreamClient( + request_exception=httpx.ReadTimeout("timeout", request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions")), + ) + + with patch.object(module, "get_http_client", return_value=stub_client): + response = await self._request( + module, + "POST", + "/ai/chat", + headers={"X-API-Key": "test-key"}, + json_payload={"messages": [{"role": "user", "content": "hello"}]}, + ) + + self.assertEqual(response.status_code, 504) + self.assertEqual(response.json()["error"]["code"], "llm_timeout") + + async def test_upstream_400_is_preserved(self): + module = load_gateway_module(llm_enabled=True) + bad_request_response = httpx.Response( + 400, + json={"error": {"message": "Bad prompt"}}, + request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions"), + ) + stub_client = StubUpstreamClient( + request_responses={("POST", f"{module.LLM_URL}/v1/chat/completions"): bad_request_response}, + ) + + with patch.object(module, "get_http_client", return_value=stub_client): + response = await self._request( + module, + "POST", + "/v1/chat/completions", + headers={"X-API-Key": "test-key"}, + json_payload={"messages": [{"role": "user", "content": "hello"}]}, + ) + + self.assertEqual(response.status_code, 400) + self.assertEqual(response.json()["error"]["code"], "llm_rejected_request") + + async def test_models_endpoint_returns_upstream_metadata(self): + module = load_gateway_module(llm_enabled=True) + models_response = httpx.Response( + 200, + json={ + "object": "list", + "data": [ + { + "id": "qwen3-1.7b-instruct-q4_k_m", + "object": "model", + "owned_by": "self-hosted", + } + ], + }, + request=httpx.Request("GET", f"{module.LLM_URL}/v1/models"), + ) + stub_client = StubUpstreamClient( + request_responses={("GET", f"{module.LLM_URL}/v1/models"): models_response}, + ) + + with patch.object(module, "get_http_client", return_value=stub_client): + response = await self._request( + module, + "GET", + "/v1/models", + headers={"X-API-Key": "test-key"}, + ) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["data"][0]["id"], "qwen3-1.7b-instruct-q4_k_m") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_llm_service.py b/tests/test_llm_service.py new file mode 100644 index 0000000..e7c906b --- /dev/null +++ b/tests/test_llm_service.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import importlib +import os +import unittest +from types import SimpleNamespace +from unittest.mock import patch + +import httpx + + +BASE_ENV = { + "MODEL_PATH": "D:/Sites/vision/models/qwen3/Qwen3-1.7B-Instruct-Q4_K_M.gguf", + "LLM_MODEL_NAME": "qwen3-1.7b-instruct-q4_k_m", + "LLM_CONTEXT_SIZE": "4096", + "LLM_THREADS": "4", + "LLM_GPU_LAYERS": "0", + "LLM_PORT": "8080", + "LLAMA_SERVER_PORT": "8081", +} + + +def load_llm_module(): + with patch.dict(os.environ, BASE_ENV, clear=False): + import llm.main as llm_main + + return importlib.reload(llm_main) + + +class StubHTTPClient: + def __init__(self, response: httpx.Response): + self.response = response + + async def get(self, *_args, **_kwargs): + return self.response + + +class LLMServiceTests(unittest.IsolatedAsyncioTestCase): + async def test_health_returns_repo_owned_contract(self): + module = load_llm_module() + module._llama_process = SimpleNamespace(poll=lambda: None) + module._http_client = StubHTTPClient( + httpx.Response(200, json={"object": "list", "data": []}, request=httpx.Request("GET", "http://127.0.0.1:8081/v1/models")) + ) + + transport = httpx.ASGITransport(app=module.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + response = await client.get("/health") + + self.assertEqual(response.status_code, 200) + self.assertEqual( + response.json(), + { + "status": "ok", + "model": "Qwen3-1.7B-Instruct-Q4_K_M.gguf", + "model_alias": "qwen3-1.7b-instruct-q4_k_m", + "context_size": 4096, + "threads": 4, + "gpu_layers": 0, + }, + ) + + async def test_health_reports_unavailable_when_process_is_down(self): + module = load_llm_module() + module._llama_process = SimpleNamespace(poll=lambda: 1) + module._http_client = StubHTTPClient( + httpx.Response(200, json={"object": "list", "data": []}, request=httpx.Request("GET", "http://127.0.0.1:8081/v1/models")) + ) + + transport = httpx.ASGITransport(app=module.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + response = await client.get("/health") + + self.assertEqual(response.status_code, 503) + self.assertEqual(response.json()["status"], "unavailable") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file