Files
vision/gateway/main.py
Gregor Klevze baf497b015 feat(maturity): add dedicated NSFW/maturity analysis service
- Add maturity/ service (FastAPI + Falconsai/nsfw_image_detection ViT classifier)
  - /analyze (URL) and /analyze/file (multipart upload) endpoints
  - Normalized response: maturity_label, confidence, score, labels,
    action_hint (safe/review/flag_high), advisory, threshold_used,
    analysis_time_ms, model, source
  - Configurable thresholds via MATURITY_THRESHOLD_MATURE / MATURITY_THRESHOLD_REVIEW
  - Reuses common/image_io for URL validation and file-size enforcement
  - Explicit 502/503 errors on failure — no silent safe fallback
  - Per-request structured logging (score, label, threshold path, elapsed ms)

- Update gateway/main.py
  - MATURITY_URL + MATURITY_ENABLED env vars
  - POST /analyze/maturity and POST /analyze/maturity/file endpoints
  - /health includes maturity service status
  - _assert_maturity_enabled() guard for clean 503 when disabled
  - All existing endpoints untouched (additive change)

- Update docker-compose.yml
  - Add maturity service with healthcheck (start_period: 90s)
  - Gateway environment: MATURITY_URL, MATURITY_ENABLED
  - Gateway depends_on: maturity (service_healthy)

- Update README.md and USAGE.md
  - Document maturity service, env vars, curl examples,
    full response schema table, action_hint logic, failure guidance
2026-04-11 17:29:26 +02:00

560 lines
20 KiB
Python

from __future__ import annotations

import asyncio
import logging
import os
import secrets
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional

import httpx
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Request
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel, Field
from starlette.middleware.base import BaseHTTPMiddleware
# Module logger; basicConfig installs the root handler at INFO level with a
# timestamp / level / logger-name prefix.
logger = logging.getLogger("gateway")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
# Base URLs of the upstream services. Each one can be overridden through the
# environment variable of the same name.
CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000")
BLIP_URL = os.getenv("BLIP_URL", "http://blip:8000")
YOLO_URL = os.getenv("YOLO_URL", "http://yolo:8000")
QDRANT_SVC_URL = os.getenv("QDRANT_SVC_URL", "http://qdrant-svc:8000")
CARD_RENDERER_URL = os.getenv("CARD_RENDERER_URL", "http://card-renderer:8000")
MATURITY_URL = os.getenv("MATURITY_URL", "http://maturity:8000")
# Maturity endpoints are enabled unless MATURITY_ENABLED is "0", "false" or
# "no" (compared case-insensitively).
MATURITY_ENABLED = os.getenv("MATURITY_ENABLED", "true").lower() not in ("0", "false", "no")
# Default timeout passed to the shared httpx client (seconds).
VISION_TIMEOUT = float(os.getenv("VISION_TIMEOUT", "20"))
# API key (set via env var `API_KEY`). If it is not set, the gateway rejects
# every request except the paths the API-key middleware exempts.
API_KEY = os.getenv("API_KEY")
# ---------------------------------------------------------------------------
# Shared persistent HTTP client — created once at startup, reused across all
# requests. This eliminates per-request TCP connect + DNS latency (the main
# cause of 20 s first-request latency observed on /vectors/inspect).
#
# The variable is None until the lifespan handler assigns the client and is
# reset to None again on shutdown; use get_http_client() to access it.
# ---------------------------------------------------------------------------
_http_client: httpx.AsyncClient | None = None
def get_http_client() -> httpx.AsyncClient:
    """Hand out the process-wide httpx client.

    Raises:
        RuntimeError: when the lifespan handler has not created the client
            (or has already torn it down).
    """
    client = _http_client
    if client is not None:
        return client
    raise RuntimeError("HTTP client not initialised — lifespan not running")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: create the shared HTTP client and warm qdrant-svc.

    On startup the shared ``httpx.AsyncClient`` is created and published via
    the module-level ``_http_client``; a best-effort ``/health`` ping is sent
    to qdrant-svc so the first real request reuses an open connection.

    On shutdown the client is closed and the global is reset. The teardown
    lives in ``finally`` so it also runs when an exception (or cancellation)
    is thrown into the generator at the ``yield`` point; without it the
    connection pool would be leaked in that case.
    """
    global _http_client
    t0 = time.perf_counter()
    logger.info("gateway startup: creating shared HTTP client")
    limits = httpx.Limits(
        max_connections=100,
        max_keepalive_connections=20,
        keepalive_expiry=30,
    )
    client = httpx.AsyncClient(
        timeout=httpx.Timeout(VISION_TIMEOUT, connect=10),
        limits=limits,
    )
    _http_client = client
    try:
        # Warm the qdrant-svc connection so the first real request does not pay
        # the TCP handshake + DNS cost. Failure is non-fatal — the service may
        # still be starting when the gateway starts.
        try:
            t_warm = time.perf_counter()
            r = await client.get(f"{QDRANT_SVC_URL}/health", timeout=8)
            logger.info(
                "gateway startup: qdrant-svc warm ping done status=%s elapsed_ms=%.1f",
                r.status_code, (time.perf_counter() - t_warm) * 1000,
            )
        except Exception as exc:
            logger.warning("gateway startup: qdrant-svc warm ping failed (non-fatal): %s", exc)
        logger.info("gateway startup complete elapsed_ms=%.1f", (time.perf_counter() - t0) * 1000)
        yield  # application runs
    finally:
        logger.info("gateway shutdown: closing shared HTTP client")
        try:
            await client.aclose()
        finally:
            # Reset the global even if closing fails, so later callers get the
            # explicit "not initialised" error instead of a half-closed client.
            _http_client = None
class APIKeyMiddleware(BaseHTTPMiddleware):
    """Reject requests that do not present the configured API key.

    The key is read from the ``X-API-Key`` request header and compared with
    the module-level ``API_KEY``. When ``API_KEY`` is unset or empty, every
    non-exempt request is rejected with 401.
    """

    # Paths served without an API key (health probe and generated API docs).
    _PUBLIC_PATHS = frozenset({"/health", "/openapi.json", "/docs", "/redoc"})

    async def dispatch(self, request: Request, call_next):
        """Pass the request on when authorised, otherwise answer 401."""
        if request.url.path in self._PUBLIC_PATHS:
            return await call_next(request)
        # Header lookup is case-insensitive, so one lookup covers every
        # spelling of the header name.
        supplied = request.headers.get("x-api-key")
        # compare_digest avoids leaking, through response timing, how many
        # leading characters of the key matched. Both sides are encoded to
        # bytes because compare_digest rejects non-ASCII str arguments.
        authorised = bool(
            API_KEY
            and supplied is not None
            and secrets.compare_digest(supplied.encode("utf-8"), API_KEY.encode("utf-8"))
        )
        if not authorised:
            return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
        return await call_next(request)
# ASGI application. `lifespan` manages the shared HTTP client; the middleware
# enforces the API key on every path it does not explicitly exempt.
app = FastAPI(title="Skinbase Vision Gateway", version="1.0.0", lifespan=lifespan)
app.add_middleware(APIKeyMiddleware)
class ClipRequest(BaseModel):
    """JSON body accepted by POST /analyze/clip."""

    # Image URL; declared optional here, the endpoint answers 400 when missing.
    url: Optional[str] = None
    # Forwarded to the CLIP service; constrained to 1..50.
    limit: int = Field(5, ge=1, le=50)
    # Forwarded to the CLIP service; when given it must lie in [0.0, 1.0].
    threshold: Optional[float] = Field(None, ge=0.0, le=1.0)
class BlipRequest(BaseModel):
    """JSON body accepted by POST /analyze/blip."""

    # Image URL; declared optional here, the endpoint answers 400 when missing.
    url: Optional[str] = None
    # Forwarded to the BLIP service; constrained to 0..10.
    variants: int = Field(3, ge=0, le=10)
    # Forwarded to the BLIP service; constrained to 10..200.
    max_length: int = Field(60, ge=10, le=200)
class YoloRequest(BaseModel):
    """JSON body accepted by POST /analyze/yolo."""

    # Image URL; declared optional here, the endpoint answers 400 when missing.
    url: Optional[str] = None
    # Forwarded to the YOLO service; constrained to [0.0, 1.0].
    conf: float = Field(0.25, ge=0.0, le=1.0)
class MaturityRequest(BaseModel):
    """JSON body accepted by POST /analyze/maturity."""

    # Image URL; declared optional here, the endpoint answers 400 when missing.
    url: Optional[str] = Field(default=None)
async def _get_health(base: str) -> Dict[str, Any]:
    """Probe ``{base}/health`` and summarise the outcome as a dict.

    Returns the upstream JSON body on HTTP 200, ``{"status": "bad", "code": …}``
    on any other status, and ``{"status": "unreachable"}`` when the request or
    the JSON decoding raises. Never raises itself (best-effort probe).
    """
    try:
        resp = await get_http_client().get(f"{base}/health", timeout=5)
        if resp.status_code != 200:
            return {"status": "bad", "code": resp.status_code}
        return resp.json()
    except Exception:
        return {"status": "unreachable"}
async def _post_json(url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON reply.

    Raises:
        HTTPException: 502 when the request fails at transport level, when the
            upstream answers with status >= 400, or when its body is not JSON.
    """
    started = time.perf_counter()
    try:
        resp = await get_http_client().post(url, json=payload)
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {exc}")
    elapsed_ms = (time.perf_counter() - started) * 1000
    logger.debug("POST %s status=%s elapsed_ms=%.1f", url, resp.status_code, elapsed_ms)
    if resp.status_code >= 400:
        raise HTTPException(
            status_code=502,
            detail=f"Upstream error {url}: {resp.status_code} {resp.text[:1000]}",
        )
    try:
        decoded = resp.json()
    except Exception:
        # Upstream answered with something that is not JSON (HTML error page
        # or empty body).
        raise HTTPException(
            status_code=502,
            detail=f"Upstream returned non-JSON at {url}: {resp.status_code} {resp.text[:1000]}",
        )
    return decoded
async def _post_file(url: str, data: bytes, fields: Dict[str, Any]) -> Dict[str, Any]:
    """POST ``data`` as a multipart file part named ``file`` plus form fields.

    Every value in ``fields`` is stringified before sending. The file part is
    always sent with filename ``image`` and type ``application/octet-stream``.

    Raises:
        HTTPException: 502 when the request fails at transport level, when the
            upstream answers with status >= 400, or when its body is not JSON.
    """
    form = {key: str(value) for key, value in fields.items()}
    upload = {"file": ("image", data, "application/octet-stream")}
    started = time.perf_counter()
    try:
        resp = await get_http_client().post(url, data=form, files=upload)
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {exc}")
    elapsed_ms = (time.perf_counter() - started) * 1000
    logger.debug("POST(file) %s status=%s elapsed_ms=%.1f", url, resp.status_code, elapsed_ms)
    if resp.status_code >= 400:
        raise HTTPException(
            status_code=502,
            detail=f"Upstream error {url}: {resp.status_code} {resp.text[:1000]}",
        )
    try:
        decoded = resp.json()
    except Exception:
        raise HTTPException(
            status_code=502,
            detail=f"Upstream returned non-JSON at {url}: {resp.status_code} {resp.text[:1000]}",
        )
    return decoded
async def _get_json(url: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """GET ``url`` (with optional query ``params``) and return the decoded JSON.

    Raises:
        HTTPException: 502 when the request fails at transport level, when the
            upstream answers with status >= 400, or when its body is not JSON.
    """
    started = time.perf_counter()
    try:
        resp = await get_http_client().get(url, params=params)
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {exc}")
    elapsed_ms = (time.perf_counter() - started) * 1000
    logger.debug("GET %s status=%s elapsed_ms=%.1f", url, resp.status_code, elapsed_ms)
    if resp.status_code >= 400:
        raise HTTPException(
            status_code=502,
            detail=f"Upstream error {url}: {resp.status_code} {resp.text[:1000]}",
        )
    try:
        decoded = resp.json()
    except Exception:
        raise HTTPException(
            status_code=502,
            detail=f"Upstream returned non-JSON at {url}: {resp.status_code} {resp.text[:1000]}",
        )
    return decoded
@app.get("/health")
async def health():
    """Report the health of each upstream service, probed concurrently.

    The top-level ``status`` is always ``"ok"``; per-service results are under
    ``services``. The maturity service is included only when it is enabled.
    """
    # Insertion order fixes both the probe order and the response key order.
    targets: Dict[str, str] = {
        "clip": CLIP_URL,
        "blip": BLIP_URL,
        "yolo": YOLO_URL,
        "qdrant": QDRANT_SVC_URL,
    }
    if MATURITY_ENABLED:
        targets["maturity"] = MATURITY_URL
    reports = await asyncio.gather(*(_get_health(base) for base in targets.values()))
    services: Dict[str, Any] = dict(zip(targets, reports))
    return {"status": "ok", "services": services}
# ---- Individual analyze endpoints (URL) ----
@app.post("/analyze/clip")
async def analyze_clip(req: ClipRequest):
    """Proxy a URL-based request to the CLIP service's /analyze endpoint.

    Answers 400 when the body carries no (or an empty) ``url``.
    """
    if req.url:
        return await _post_json(f"{CLIP_URL}/analyze", req.model_dump())
    raise HTTPException(400, "url is required")
@app.post("/analyze/blip")
async def analyze_blip(req: BlipRequest):
    """Proxy a URL-based request to the BLIP service's /caption endpoint.

    Answers 400 when the body carries no (or an empty) ``url``.
    """
    if req.url:
        return await _post_json(f"{BLIP_URL}/caption", req.model_dump())
    raise HTTPException(400, "url is required")
@app.post("/analyze/yolo")
async def analyze_yolo(req: YoloRequest):
    """Proxy a URL-based request to the YOLO service's /detect endpoint.

    Answers 400 when the body carries no (or an empty) ``url``.
    """
    if req.url:
        return await _post_json(f"{YOLO_URL}/detect", req.model_dump())
    raise HTTPException(400, "url is required")
# ---- Individual analyze endpoints (file upload) ----
@app.post("/analyze/clip/file")
async def analyze_clip_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    threshold: Optional[float] = Form(None),
):
    """Proxy an uploaded image to the CLIP service's /analyze/file endpoint.

    ``threshold`` is forwarded only when the client supplied it.
    """
    image_bytes = await file.read()
    form_fields: Dict[str, Any] = {"limit": int(limit)}
    if threshold is not None:
        form_fields["threshold"] = float(threshold)
    upstream = f"{CLIP_URL}/analyze/file"
    return await _post_file(upstream, image_bytes, form_fields)
@app.post("/analyze/blip/file")
async def analyze_blip_file(
    file: UploadFile = File(...),
    variants: int = Form(3),
    max_length: int = Form(60),
):
    """Proxy an uploaded image to the BLIP service's /caption/file endpoint."""
    image_bytes = await file.read()
    upstream = f"{BLIP_URL}/caption/file"
    return await _post_file(
        upstream,
        image_bytes,
        {"variants": int(variants), "max_length": int(max_length)},
    )
@app.post("/analyze/yolo/file")
async def analyze_yolo_file(
    file: UploadFile = File(...),
    conf: float = Form(0.25),
):
    """Proxy an uploaded image to the YOLO service's /detect/file endpoint."""
    image_bytes = await file.read()
    upstream = f"{YOLO_URL}/detect/file"
    return await _post_file(upstream, image_bytes, {"conf": float(conf)})
@app.post("/analyze/all")
async def analyze_all(payload: Dict[str, Any]):
    """Run CLIP, BLIP and YOLO on one image URL concurrently.

    The body is a free-form JSON object. ``url`` is required; ``limit``,
    ``threshold``, ``variants``, ``max_length`` and ``conf`` are optional and
    forwarded to the matching service.

    Raises:
        HTTPException: 400 when ``url`` is missing/empty or when a numeric
            option cannot be converted (previously this surfaced as a 500);
            502 when any upstream call fails.
    """
    url = payload.get("url")
    if not url:
        raise HTTPException(400, "url is required")
    # The body is an untyped dict, so client-supplied values may be strings,
    # nulls or nested objects; turn conversion failures into a client error.
    try:
        limit = int(payload.get("limit", 5))
        variants = int(payload.get("variants", 3))
        max_length = int(payload.get("max_length", 60))
        conf = float(payload.get("conf", 0.25))
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=f"invalid numeric parameter: {exc}") from exc
    clip_req = {"url": url, "limit": limit, "threshold": payload.get("threshold")}
    blip_req = {"url": url, "variants": variants, "max_length": max_length}
    yolo_req = {"url": url, "conf": conf}
    clip_res, blip_res, yolo_res = await asyncio.gather(
        _post_json(f"{CLIP_URL}/analyze", clip_req),
        _post_json(f"{BLIP_URL}/caption", blip_req),
        _post_json(f"{YOLO_URL}/detect", yolo_req),
    )
    return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res}
# ---- Vector / Qdrant endpoints ----
@app.post("/vectors/upsert")
async def vectors_upsert(payload: Dict[str, Any]):
    """Relay the JSON body unchanged to qdrant-svc's /upsert endpoint."""
    upstream = f"{QDRANT_SVC_URL}/upsert"
    return await _post_json(upstream, payload)
@app.post("/vectors/upsert/file")
async def vectors_upsert_file(
    file: UploadFile = File(...),
    id: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
    metadata_json: Optional[str] = Form(None),
):
    """Relay an uploaded image to qdrant-svc's /upsert/file endpoint.

    Only the form fields the client actually supplied are forwarded.
    """
    image_bytes = await file.read()
    candidates = (
        ("id", id),
        ("collection", collection),
        ("metadata_json", metadata_json),
    )
    form_fields: Dict[str, Any] = {
        key: value for key, value in candidates if value is not None
    }
    return await _post_file(f"{QDRANT_SVC_URL}/upsert/file", image_bytes, form_fields)
@app.post("/vectors/upsert/vector")
async def vectors_upsert_vector(payload: Dict[str, Any]):
    """Relay the JSON body unchanged to qdrant-svc's /upsert/vector endpoint."""
    upstream = f"{QDRANT_SVC_URL}/upsert/vector"
    return await _post_json(upstream, payload)
@app.post("/vectors/search")
async def vectors_search(payload: Dict[str, Any]):
    """Relay the JSON body unchanged to qdrant-svc's /search endpoint."""
    upstream = f"{QDRANT_SVC_URL}/search"
    return await _post_json(upstream, payload)
@app.post("/vectors/search/file")
async def vectors_search_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    score_threshold: Optional[float] = Form(None),
    collection: Optional[str] = Form(None),
    hnsw_ef: Optional[int] = Form(None),
    exact: bool = Form(False),
    indexed_only: bool = Form(False),
    filter_metadata_json: Optional[str] = Form(None),
):
    """Relay an uploaded image to qdrant-svc's /search/file endpoint.

    ``limit``, ``exact`` and ``indexed_only`` are always forwarded; the
    remaining options are forwarded only when the client supplied them.
    """
    image_bytes = await file.read()
    form_fields: Dict[str, Any] = {
        "limit": int(limit),
        "exact": exact,
        "indexed_only": indexed_only,
    }
    # Optional settings, in the order they are added to the form.
    optional: Dict[str, Any] = {
        "score_threshold": None if score_threshold is None else float(score_threshold),
        "collection": collection,
        "hnsw_ef": None if hnsw_ef is None else int(hnsw_ef),
        "filter_metadata_json": filter_metadata_json,
    }
    form_fields.update(
        {key: value for key, value in optional.items() if value is not None}
    )
    return await _post_file(f"{QDRANT_SVC_URL}/search/file", image_bytes, form_fields)
@app.post("/vectors/search/vector")
async def vectors_search_vector(payload: Dict[str, Any]):
    """Relay the JSON body unchanged to qdrant-svc's /search/vector endpoint."""
    upstream = f"{QDRANT_SVC_URL}/search/vector"
    return await _post_json(upstream, payload)
@app.post("/vectors/delete")
async def vectors_delete(payload: Dict[str, Any]):
    """Relay the JSON body unchanged to qdrant-svc's /delete endpoint."""
    upstream = f"{QDRANT_SVC_URL}/delete"
    return await _post_json(upstream, payload)
@app.get("/vectors/collections")
async def vectors_collections():
    """Return whatever qdrant-svc's GET /collections endpoint answers."""
    upstream = f"{QDRANT_SVC_URL}/collections"
    return await _get_json(upstream)
@app.post("/vectors/collections")
async def vectors_create_collection(payload: Dict[str, Any]):
    """Relay the JSON body unchanged to qdrant-svc's POST /collections endpoint."""
    upstream = f"{QDRANT_SVC_URL}/collections"
    return await _post_json(upstream, payload)
@app.get("/vectors/collections/{name}")
async def vectors_collection_info(name: str):
    """Return whatever qdrant-svc's GET /collections/{name} endpoint answers."""
    upstream = f"{QDRANT_SVC_URL}/collections/{name}"
    return await _get_json(upstream)
@app.get("/vectors/inspect")
async def vectors_inspect():
    """Diagnostic summary of all Qdrant collections, fetched from qdrant-svc.

    Covers HNSW, optimizer, payload indexes and a RAM estimate. Start and
    completion (with elapsed time) are logged at INFO level.
    """
    started = time.perf_counter()
    logger.info("vectors_inspect: start")
    summary = await _get_json(f"{QDRANT_SVC_URL}/inspect")
    elapsed_ms = (time.perf_counter() - started) * 1000
    logger.info("vectors_inspect: done elapsed_ms=%.1f", elapsed_ms)
    return summary
@app.delete("/vectors/collections/{name}")
async def vectors_delete_collection(name: str):
    """Forward a DELETE for collection ``name`` to qdrant-svc.

    Returns the upstream JSON reply.

    Raises:
        HTTPException: 502 when the request fails at transport level, when the
            upstream answers with status >= 400, or when its body is not JSON
            (previously an undecodable body escaped as an unhandled 500,
            unlike the other upstream helpers in this module).
    """
    try:
        r = await get_http_client().delete(f"{QDRANT_SVC_URL}/collections/{name}")
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"Upstream request failed: {exc}") from exc
    if r.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}")
    try:
        return r.json()
    except ValueError as exc:
        # JSON and unicode decode errors are both ValueError subclasses.
        raise HTTPException(
            status_code=502,
            detail=f"Upstream returned non-JSON: {r.status_code}",
        ) from exc
@app.get("/vectors/points/{point_id}")
async def vectors_get_point(point_id: str, collection: Optional[str] = None):
    """Fetch a point from qdrant-svc's /points/{point_id} endpoint.

    ``collection`` is passed on as a query parameter only when non-empty.
    """
    query = {"collection": collection} if collection else {}
    return await _get_json(f"{QDRANT_SVC_URL}/points/{point_id}", params=query)
@app.get("/vectors/points/by-original-id/{original_id}")
async def vectors_get_point_by_original_id(original_id: str, collection: Optional[str] = None):
    """Fetch a point from qdrant-svc's /points/by-original-id/{original_id} endpoint.

    ``collection`` is passed on as a query parameter only when non-empty.
    """
    query = {"collection": collection} if collection else {}
    return await _get_json(
        f"{QDRANT_SVC_URL}/points/by-original-id/{original_id}",
        params=query,
    )
# ---- File-based universal analyze ----
@app.post("/analyze/all/file")
async def analyze_all_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    variants: int = Form(3),
    conf: float = Form(0.25),
    max_length: int = Form(60),
):
    """Run CLIP, BLIP and YOLO on one uploaded image concurrently.

    The upload is read once and the same bytes are sent to all three services.
    """
    image_bytes = await file.read()
    pending = (
        _post_file(f"{CLIP_URL}/analyze/file", image_bytes, {"limit": limit}),
        _post_file(
            f"{BLIP_URL}/caption/file",
            image_bytes,
            {"variants": variants, "max_length": max_length},
        ),
        _post_file(f"{YOLO_URL}/detect/file", image_bytes, {"conf": conf}),
    )
    outcomes = await asyncio.gather(*pending)
    return dict(zip(("clip", "blip", "yolo"), outcomes))
# ---- Maturity / NSFW analysis endpoints ----
def _assert_maturity_enabled() -> None:
    """Raise a 503 HTTPException when the maturity service is switched off."""
    if MATURITY_ENABLED:
        return
    raise HTTPException(status_code=503, detail="Maturity service is disabled")
@app.post("/analyze/maturity")
async def analyze_maturity(req: MaturityRequest):
    """Check an image URL for maturity / NSFW content via the maturity service.

    The reply is the service's normalized maturity signal: maturity_label
    (safe/mature), confidence, score, optional sublabels, and an action_hint
    for Nova moderation.

    Answers 503 when the maturity service is disabled and 400 when ``url`` is
    missing or empty.
    """
    _assert_maturity_enabled()
    image_url = req.url
    if not image_url:
        raise HTTPException(status_code=400, detail="url is required")
    logger.info("analyze_maturity: url=%s", image_url)
    upstream = f"{MATURITY_URL}/analyze"
    return await _post_json(upstream, req.model_dump())
@app.post("/analyze/maturity/file")
async def analyze_maturity_file(file: UploadFile = File(...)):
    """Check an uploaded image for maturity / NSFW content.

    Produces the same normalized maturity signal as /analyze/maturity.
    Answers 503 when the maturity service is disabled.
    """
    _assert_maturity_enabled()
    image_bytes = await file.read()
    logger.info(
        "analyze_maturity_file: filename=%s size=%d",
        file.filename,
        len(image_bytes),
    )
    upstream = f"{MATURITY_URL}/analyze/file"
    return await _post_file(upstream, image_bytes, {})
# ---- Card renderer endpoints ----
@app.get("/cards/templates")
async def cards_templates():
    """Return the card templates reported by the card renderer's /templates endpoint."""
    upstream = f"{CARD_RENDERER_URL}/templates"
    return await _get_json(upstream)
@app.post("/cards/render")
async def cards_render(payload: Dict[str, Any]):
    """Render a Nova card from a remote image URL and return the image bytes.

    The JSON body is relayed unchanged to the card renderer. The response
    reuses the upstream content type, defaulting to ``image/webp``.

    Raises:
        HTTPException: 502 when the renderer is unreachable or answers with
            status >= 400.
    """
    upstream = f"{CARD_RENDERER_URL}/render"
    try:
        rendered = await get_http_client().post(upstream, json=payload)
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"card-renderer unreachable: {exc}")
    if rendered.status_code >= 400:
        raise HTTPException(
            status_code=502,
            detail=f"card-renderer error {rendered.status_code}: {rendered.text[:1000]}",
        )
    content_type = rendered.headers.get("content-type", "image/webp")
    return Response(content=rendered.content, media_type=content_type)
@app.post("/cards/render/file")
async def cards_render_file(
    file: UploadFile = File(...),
    template: str = Form("nova-artwork-v1"),
    width: int = Form(1200),
    height: int = Form(630),
    output: str = Form("webp"),
    quality: int = Form(90),
    title: Optional[str] = Form(None),
    subtitle: Optional[str] = Form(None),
    username: Optional[str] = Form(None),
    category: Optional[str] = Form(None),
    tags_json: Optional[str] = Form(None),
    show_logo: bool = Form(True),
):
    """Render a Nova card from an uploaded image and return the image bytes.

    Layout options are always forwarded; the text options are forwarded only
    when supplied. The upload keeps its original filename and content type
    when the client provided them. The response reuses the upstream content
    type, defaulting to ``image/webp``.

    Raises:
        HTTPException: 502 when the renderer is unreachable or answers with
            status >= 400.
    """
    image_bytes = await file.read()
    form_fields: Dict[str, Any] = {
        "template": template,
        "width": width,
        "height": height,
        "output": output,
        "quality": quality,
        "show_logo": show_logo,
    }
    # Optional text settings, in the order they are added to the form.
    optional_text = {
        "title": title,
        "subtitle": subtitle,
        "username": username,
        "category": category,
        "tags_json": tags_json,
    }
    form_fields.update(
        {key: value for key, value in optional_text.items() if value is not None}
    )
    encoded_form = {key: str(value) for key, value in form_fields.items()}
    upload_part = (
        file.filename or "image",
        image_bytes,
        file.content_type or "application/octet-stream",
    )
    try:
        rendered = await get_http_client().post(
            f"{CARD_RENDERER_URL}/render/file",
            data=encoded_form,
            files={"file": upload_part},
        )
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"card-renderer unreachable: {exc}")
    if rendered.status_code >= 400:
        raise HTTPException(
            status_code=502,
            detail=f"card-renderer error {rendered.status_code}: {rendered.text[:1000]}",
        )
    content_type = rendered.headers.get("content-type", "image/webp")
    return Response(content=rendered.content, media_type=content_type)
@app.post("/cards/render/meta")
async def cards_render_meta(payload: Dict[str, Any]):
    """Fetch crop and layout metadata for a card render; no image is produced."""
    upstream = f"{CARD_RENDERER_URL}/render/meta"
    return await _post_json(upstream, payload)
# ---- Qdrant administration endpoints (index management + collection config) ----
@app.get("/vectors/collections/{name}/indexes")
async def vectors_collection_indexes(name: str):
    """Return the payload indexes of collection ``name`` as reported by qdrant-svc."""
    upstream = f"{QDRANT_SVC_URL}/collections/{name}/indexes"
    return await _get_json(upstream)
@app.post("/vectors/collections/{name}/indexes")
async def vectors_create_payload_index(name: str, payload: Dict[str, Any]):
    """Ask qdrant-svc to create a payload index on a field of collection ``name``."""
    upstream = f"{QDRANT_SVC_URL}/collections/{name}/indexes"
    return await _post_json(upstream, payload)
@app.post("/vectors/collections/{name}/ensure-indexes")
async def vectors_ensure_indexes(name: str, payload: Dict[str, Any]):
    """Ask qdrant-svc to idempotently ensure payload indexes exist for a list of fields."""
    upstream = f"{QDRANT_SVC_URL}/collections/{name}/ensure-indexes"
    return await _post_json(upstream, payload)
@app.post("/vectors/collections/{name}/configure")
async def vectors_configure_collection(name: str, payload: Dict[str, Any]):
    """Ask qdrant-svc to update the HNSW and optimizer configuration of collection ``name``."""
    upstream = f"{QDRANT_SVC_URL}/collections/{name}/configure"
    return await _post_json(upstream, payload)