feat(maturity): add dedicated NSFW/maturity analysis service
- Add maturity/ service (FastAPI + Falconsai/nsfw_image_detection ViT classifier)
- /analyze (URL) and /analyze/file (multipart upload) endpoints
- Normalized response: maturity_label, confidence, score, labels,
action_hint (safe/review/flag_high), advisory, threshold_used,
analysis_time_ms, model, source
- Configurable thresholds via MATURITY_THRESHOLD_MATURE / MATURITY_THRESHOLD_REVIEW
- Reuses common/image_io for URL validation and file-size enforcement
- Explicit 502/503 errors on failure — no silent safe fallback
- Per-request structured logging (score, label, threshold path, elapsed ms)
- Update gateway/main.py
- MATURITY_URL + MATURITY_ENABLED env vars
- POST /analyze/maturity and POST /analyze/maturity/file endpoints
- /health includes maturity service status
- _assert_maturity_enabled() guard for clean 503 when disabled
- All existing endpoints untouched (additive change)
- Update docker-compose.yml
- Add maturity service with healthcheck (start_period: 90s)
- Gateway environment: MATURITY_URL, MATURITY_ENABLED
- Gateway depends_on: maturity (service_healthy)
- Update README.md and USAGE.md
- Document maturity service, env vars, curl examples,
full response schema table, action_hint logic, failure guidance
This commit is contained in:
52
README.md
52
README.md
@@ -1,6 +1,6 @@
|
||||
# Skinbase Vision Stack (CLIP + BLIP + YOLO + Qdrant + Card Renderer) – Dockerized FastAPI
|
||||
# Skinbase Vision Stack (CLIP + BLIP + YOLO + Qdrant + Card Renderer + Maturity) – Dockerized FastAPI
|
||||
|
||||
This repository provides **five standalone vision services** (CLIP / BLIP / YOLO / Qdrant / Card Renderer)
|
||||
This repository provides **six standalone vision services** (CLIP / BLIP / YOLO / Qdrant / Card Renderer / Maturity)
|
||||
and a **Gateway API** that can call them individually or together.
|
||||
|
||||
## Services & Ports
|
||||
@@ -12,6 +12,7 @@ and a **Gateway API** that can call them individually or together.
|
||||
- `qdrant`: vector DB (port `6333` exposed for direct access)
|
||||
- `qdrant-svc`: internal Qdrant API wrapper
|
||||
- `card-renderer`: internal card rendering service
|
||||
- `maturity`: internal NSFW/maturity classifier service
|
||||
|
||||
## Run
|
||||
|
||||
@@ -30,6 +31,15 @@ HUGGINGFACE_TOKEN=your_huggingface_token_here
|
||||
|
||||
`HUGGINGFACE_TOKEN` is required when the configured BLIP model is private, gated, or otherwise requires Hugging Face authentication.
|
||||
|
||||
Optional maturity configuration (override in `.env` if needed):
|
||||
|
||||
```bash
|
||||
MATURITY_MODEL=Falconsai/nsfw_image_detection
|
||||
MATURITY_THRESHOLD_MATURE=0.80
|
||||
MATURITY_THRESHOLD_REVIEW=0.60
|
||||
MATURITY_ENABLED=true
|
||||
```
|
||||
|
||||
Service startup now waits on container healthchecks, so first boot may take longer while models finish loading.
|
||||
|
||||
## Health
|
||||
@@ -96,6 +106,41 @@ curl -H "X-API-Key: <your-api-key>" -X POST https://vision.klevze.net/analyze/yo
|
||||
-F "conf=0.25"
|
||||
```
|
||||
|
||||
## Maturity / NSFW analysis
|
||||
|
||||
Analyzes an image and returns a normalized maturity signal for Nova moderation workflows.
|
||||
|
||||
### Analyze by URL
|
||||
```bash
|
||||
curl -H "X-API-Key: <your-api-key>" -X POST https://vision.klevze.net/analyze/maturity \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp"}'
|
||||
```
|
||||
|
||||
### Analyze from file upload
|
||||
```bash
|
||||
curl -H "X-API-Key: <your-api-key>" -X POST https://vision.klevze.net/analyze/maturity/file \
|
||||
-F "file=@/path/to/image.webp"
|
||||
```
|
||||
|
||||
Example response:
|
||||
```json
|
||||
{
|
||||
"maturity_label": "mature",
|
||||
"confidence": 0.94,
|
||||
"score": 0.94,
|
||||
"labels": ["nsfw"],
|
||||
"model": "Falconsai/nsfw_image_detection",
|
||||
"threshold_used": 0.80,
|
||||
"analysis_time_ms": 183.0,
|
||||
"source": "maturity-service",
|
||||
"action_hint": "flag_high",
|
||||
"advisory": "High-confidence mature content detected"
|
||||
}
|
||||
```
|
||||
|
||||
`action_hint` values: `safe`, `review`, `flag_high`. Nova should use these to decide blur/queue/flag behaviour.
|
||||
|
||||
## Vector DB (Qdrant) via gateway
|
||||
|
||||
Qdrant point IDs must be either:
|
||||
@@ -228,8 +273,9 @@ curl -H "X-API-Key: <your-api-key>" -X POST https://vision.klevze.net/cards/rend
|
||||
|
||||
## Notes
|
||||
|
||||
- This is a **starter scaffold**. Models are loaded at service startup.
|
||||
- Models are loaded at service startup; initial container start can take 1–2 minutes as model weights are downloaded.
|
||||
- Qdrant data is persisted in the project folder at `./data/qdrant`, so it survives container restarts and recreates.
|
||||
- Remote image URLs are restricted to public `http`/`https` hosts. Localhost, private IP ranges, and non-image content types are rejected.
|
||||
- The maturity service uses `Falconsai/nsfw_image_detection` (ViT-based). Thresholds are configurable via `.env`. The model handles photos and stylized digital art but should be calibrated against real Skinbase content before production use.
|
||||
- For production: add auth, rate limits, and restrict gateway exposure (private network).
|
||||
- GPU: you can add NVIDIA runtime later (compose profiles) if needed.
|
||||
|
||||
84
USAGE.md
84
USAGE.md
@@ -1,10 +1,10 @@
|
||||
# Skinbase Vision Stack — Usage Guide
|
||||
|
||||
This document explains how to run and use the Skinbase Vision Stack (Gateway + CLIP, BLIP, YOLO, Qdrant services).
|
||||
This document explains how to run and use the Skinbase Vision Stack (Gateway + CLIP, BLIP, YOLO, Qdrant, Card Renderer, Maturity services).
|
||||
|
||||
## Overview
|
||||
|
||||
- Services: `gateway`, `clip`, `blip`, `yolo`, `qdrant`, `qdrant-svc`, `card-renderer` (FastAPI each, except `qdrant` which is the official Qdrant DB).
|
||||
- Services: `gateway`, `clip`, `blip`, `yolo`, `qdrant`, `qdrant-svc`, `card-renderer`, `maturity` (FastAPI each, except `qdrant` which is the official Qdrant DB).
|
||||
- Gateway is the public API endpoint; the other services are internal.
|
||||
|
||||
## Model overview
|
||||
@@ -19,6 +19,8 @@ This document explains how to run and use the Skinbase Vision Stack (Gateway + C
|
||||
|
||||
- **Card Renderer**: Generates branded social-card images (e.g. Open Graph previews) from artwork images. Applies smart center-weighted cropping, gradient overlays, title/username/tag text, and an optional logo. Returns binary image bytes (WebP by default). Template: `nova-artwork-v1`.
|
||||
|
||||
- **Maturity**: Dedicated NSFW/maturity classifier. Accepts an image and returns a normalized safety signal including `maturity_label` (`safe`/`mature`), `confidence`, raw `score`, optional sublabels (e.g. `nsfw`), and an `action_hint` (`safe`, `review`, `flag_high`) designed for Nova moderation workflows. Powered by `Falconsai/nsfw_image_detection` (ViT-based, HuggingFace). Thresholds are configurable via environment variables.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Docker Desktop (with `docker compose`) or a Docker environment.
|
||||
@@ -40,6 +42,19 @@ Notes:
|
||||
- `HUGGINGFACE_TOKEN` is required if the configured BLIP model requires Hugging Face authentication.
|
||||
- Startup uses container healthchecks, so initial boot can take longer while models download and warm up.
|
||||
|
||||
Optional maturity configuration (can be added to `.env` to override defaults):
|
||||
|
||||
```bash
|
||||
MATURITY_MODEL=Falconsai/nsfw_image_detection
|
||||
MATURITY_THRESHOLD_MATURE=0.80
|
||||
MATURITY_THRESHOLD_REVIEW=0.60
|
||||
MATURITY_ENABLED=true
|
||||
```
|
||||
|
||||
- `MATURITY_THRESHOLD_MATURE`: score above this → `mature` + `flag_high` (default `0.80`).
|
||||
- `MATURITY_THRESHOLD_REVIEW`: score above this but below mature threshold → `mature` + `review` (default `0.60`).
|
||||
- `MATURITY_ENABLED`: set to `false` to disable maturity endpoints at the gateway without removing the service.
|
||||
|
||||
Run from repository root:
|
||||
|
||||
```bash
|
||||
@@ -168,9 +183,65 @@ Parameters:
|
||||
|
||||
Return: detected objects with `class`, `confidence`, and `bbox` (bounding box coordinates).
|
||||
|
||||
### Qdrant — vector storage & similarity search
|
||||
### Maturity — NSFW / maturity analysis
|
||||
|
||||
The Qdrant integration lets you store image embeddings and find visually similar images. Embeddings are generated automatically by the CLIP service.
|
||||
Analyzes an image for mature or NSFW content and returns a structured signal intended for Nova moderation workflows.
|
||||
|
||||
URL request:
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/maturity \
|
||||
-H "X-API-Key: <your-api-key>" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp"}'
|
||||
```
|
||||
|
||||
File upload:
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/maturity/file \
|
||||
-H "X-API-Key: <your-api-key>" \
|
||||
-F "file=@/path/to/image.webp"
|
||||
```
|
||||
|
||||
Example response:
|
||||
|
||||
```json
|
||||
{
|
||||
"maturity_label": "mature",
|
||||
"confidence": 0.94,
|
||||
"score": 0.94,
|
||||
"labels": ["nsfw"],
|
||||
"model": "Falconsai/nsfw_image_detection",
|
||||
"threshold_used": 0.80,
|
||||
"analysis_time_ms": 183.0,
|
||||
"source": "maturity-service",
|
||||
"action_hint": "flag_high",
|
||||
"advisory": "High-confidence mature content detected"
|
||||
}
|
||||
```
|
||||
|
||||
Response fields:
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `maturity_label` | string | `safe` or `mature` |
|
||||
| `confidence` | float | Confidence in the label decision (0–1). For `safe`, this is `1 - score`. |
|
||||
| `score` | float | Raw NSFW probability from the model (0–1). |
|
||||
| `labels` | array | Sublabels when mature: currently `["nsfw"]`. Empty for safe results. |
|
||||
| `model` | string | Model identifier / HuggingFace model ID. |
|
||||
| `threshold_used` | float | The threshold value that determined the label. |
|
||||
| `analysis_time_ms` | float | Inference time in milliseconds. |
|
||||
| `source` | string | Always `maturity-service`. |
|
||||
| `action_hint` | string | `safe`, `review`, or `flag_high`. Use this in Nova to drive blur/queue/flag decisions. |
|
||||
| `advisory` | string | Short human-readable explanation. |
|
||||
|
||||
`action_hint` decision logic:
|
||||
- `flag_high`: score ≥ `MATURITY_THRESHOLD_MATURE` (default 0.80) — high-confidence mature, flag for moderation.
|
||||
- `review`: score ≥ `MATURITY_THRESHOLD_REVIEW` (default 0.60) but below mature threshold — possible mature, queue for human review.
|
||||
- `safe`: score below both thresholds — content appears safe.
|
||||
|
||||
If the maturity service is unavailable the gateway returns a `502` or `503` error. **Nova must not treat a gateway failure as a `safe` result** — retry or queue for later processing. store image embeddings and find visually similar images. Embeddings are generated automatically by the CLIP service.
|
||||
|
||||
Qdrant point IDs must be either an unsigned integer or a UUID string. If you send another string value, the wrapper may replace it with a generated UUID and store the original value in metadata as `_original_id`.
|
||||
|
||||
@@ -457,7 +528,9 @@ uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
- Qdrant upsert error about invalid point ID: use a UUID or unsigned integer for `id`, or omit it and use the returned generated `id`.
|
||||
- Image URL rejected before download: the URL may point to localhost, a private IP, a non-`http/https` scheme, or a non-image content type.
|
||||
- High memory / OOM: increase host memory or reduce model footprint; consider GPUs.
|
||||
- Slow startup: model weights load on service startup — expect extra time.
|
||||
- Slow startup: model weights load on service startup — expect extra time. The maturity service (`start_period: 90s`) may take longer on first boot as it downloads the classifier weights (~330 MB). Mount `~/.cache/huggingface` as a volume to persist across rebuilds.
|
||||
- Maturity endpoint returns `503`: `MATURITY_ENABLED` is set to `false` in environment configuration.
|
||||
- Maturity endpoint returns `502`: the maturity container is unhealthy or still starting up; wait and retry.
|
||||
|
||||
## Extending
|
||||
|
||||
@@ -469,6 +542,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
- `docker-compose.yml` — composition and service definitions.
|
||||
- `gateway/` — gateway FastAPI server.
|
||||
- `clip/`, `blip/`, `yolo/` — service implementations and Dockerfiles.
|
||||
- `maturity/` — NSFW/maturity classifier service (ViT-based, HuggingFace `Falconsai/nsfw_image_detection`).
|
||||
- `qdrant/` — Qdrant API wrapper service (FastAPI).
|
||||
- `card-renderer/` — card rendering service (FastAPI).
|
||||
- `common/` — shared helpers (e.g., image I/O).
|
||||
|
||||
@@ -13,6 +13,8 @@ services:
|
||||
- YOLO_URL=http://yolo:8000
|
||||
- QDRANT_SVC_URL=http://qdrant-svc:8000
|
||||
- CARD_RENDERER_URL=http://card-renderer:8000
|
||||
- MATURITY_URL=http://maturity:8000
|
||||
- MATURITY_ENABLED=true
|
||||
- API_KEY=${API_KEY}
|
||||
- VISION_TIMEOUT=300
|
||||
- MAX_IMAGE_BYTES=52428800
|
||||
@@ -27,6 +29,8 @@ services:
|
||||
condition: service_healthy
|
||||
card-renderer:
|
||||
condition: service_healthy
|
||||
maturity:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=5).read()"]
|
||||
interval: 30s
|
||||
@@ -131,3 +135,19 @@ services:
|
||||
retries: 5
|
||||
start_period: 60s
|
||||
|
||||
maturity:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: maturity/Dockerfile
|
||||
environment:
|
||||
- MATURITY_MODEL=${MATURITY_MODEL:-Falconsai/nsfw_image_detection}
|
||||
- MATURITY_THRESHOLD_MATURE=${MATURITY_THRESHOLD_MATURE:-0.80}
|
||||
- MATURITY_THRESHOLD_REVIEW=${MATURITY_THRESHOLD_REVIEW:-0.60}
|
||||
- MAX_IMAGE_BYTES=52428800
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=5).read()"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
start_period: 90s
|
||||
|
||||
|
||||
272
gateway/main.py
272
gateway/main.py
@@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
@@ -10,16 +13,76 @@ from fastapi.responses import JSONResponse, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger("gateway")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
|
||||
CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000")
|
||||
BLIP_URL = os.getenv("BLIP_URL", "http://blip:8000")
|
||||
YOLO_URL = os.getenv("YOLO_URL", "http://yolo:8000")
|
||||
QDRANT_SVC_URL = os.getenv("QDRANT_SVC_URL", "http://qdrant-svc:8000")
|
||||
CARD_RENDERER_URL = os.getenv("CARD_RENDERER_URL", "http://card-renderer:8000")
|
||||
MATURITY_URL = os.getenv("MATURITY_URL", "http://maturity:8000")
|
||||
MATURITY_ENABLED = os.getenv("MATURITY_ENABLED", "true").lower() not in ("0", "false", "no")
|
||||
VISION_TIMEOUT = float(os.getenv("VISION_TIMEOUT", "20"))
|
||||
|
||||
# API key (set via env var `API_KEY`). If not set, gateway will reject requests.
|
||||
API_KEY = os.getenv("API_KEY")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared persistent HTTP client — created once at startup, reused across all
|
||||
# requests. This eliminates per-request TCP connect + DNS latency (the main
|
||||
# cause of 20 s first-request latency observed on /vectors/inspect).
|
||||
# ---------------------------------------------------------------------------
|
||||
_http_client: httpx.AsyncClient | None = None
|
||||
|
||||
|
||||
def get_http_client() -> httpx.AsyncClient:
|
||||
"""Return the shared httpx client. Raises if called before lifespan starts."""
|
||||
if _http_client is None:
|
||||
raise RuntimeError("HTTP client not initialised — lifespan not running")
|
||||
return _http_client
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan: create shared HTTP client and warm upstream connections."""
|
||||
global _http_client
|
||||
|
||||
t0 = time.perf_counter()
|
||||
logger.info("gateway startup: creating shared HTTP client")
|
||||
|
||||
limits = httpx.Limits(
|
||||
max_connections=100,
|
||||
max_keepalive_connections=20,
|
||||
keepalive_expiry=30,
|
||||
)
|
||||
_http_client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(VISION_TIMEOUT, connect=10),
|
||||
limits=limits,
|
||||
)
|
||||
|
||||
# Warm the qdrant-svc connection so the first real request does not pay
|
||||
# the TCP handshake + DNS cost. Failure is non-fatal — the service may
|
||||
# still be starting when the gateway starts.
|
||||
try:
|
||||
t_warm = time.perf_counter()
|
||||
r = await _http_client.get(f"{QDRANT_SVC_URL}/health", timeout=8)
|
||||
logger.info(
|
||||
"gateway startup: qdrant-svc warm ping done status=%s elapsed_ms=%.1f",
|
||||
r.status_code, (time.perf_counter() - t_warm) * 1000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("gateway startup: qdrant-svc warm ping failed (non-fatal): %s", exc)
|
||||
|
||||
logger.info("gateway startup complete elapsed_ms=%.1f", (time.perf_counter() - t0) * 1000)
|
||||
|
||||
yield # application runs
|
||||
|
||||
logger.info("gateway shutdown: closing shared HTTP client")
|
||||
await _http_client.aclose()
|
||||
_http_client = None
|
||||
|
||||
|
||||
class APIKeyMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# allow health and docs endpoints without API key
|
||||
@@ -30,7 +93,7 @@ class APIKeyMiddleware(BaseHTTPMiddleware):
|
||||
return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
|
||||
return await call_next(request)
|
||||
|
||||
app = FastAPI(title="Skinbase Vision Gateway", version="1.0.0")
|
||||
app = FastAPI(title="Skinbase Vision Gateway", version="1.0.0", lifespan=lifespan)
|
||||
app.add_middleware(APIKeyMiddleware)
|
||||
|
||||
|
||||
@@ -51,19 +114,26 @@ class YoloRequest(BaseModel):
|
||||
conf: float = Field(default=0.25, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
async def _get_health(client: httpx.AsyncClient, base: str) -> Dict[str, Any]:
|
||||
class MaturityRequest(BaseModel):
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
async def _get_health(base: str) -> Dict[str, Any]:
|
||||
try:
|
||||
r = await client.get(f"{base}/health")
|
||||
r = await get_http_client().get(f"{base}/health", timeout=5)
|
||||
return r.json() if r.status_code == 200 else {"status": "bad", "code": r.status_code}
|
||||
except Exception:
|
||||
return {"status": "unreachable"}
|
||||
|
||||
|
||||
async def _post_json(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def _post_json(url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
r = await client.post(url, json=payload)
|
||||
r = await get_http_client().post(url, json=payload)
|
||||
except httpx.RequestError as e:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}")
|
||||
elapsed = (time.perf_counter() - t0) * 1000
|
||||
logger.debug("POST %s status=%s elapsed_ms=%.1f", url, r.status_code, elapsed)
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:1000]}")
|
||||
try:
|
||||
@@ -73,12 +143,15 @@ async def _post_json(client: httpx.AsyncClient, url: str, payload: Dict[str, Any
|
||||
raise HTTPException(status_code=502, detail=f"Upstream returned non-JSON at {url}: {r.status_code} {r.text[:1000]}")
|
||||
|
||||
|
||||
async def _post_file(client: httpx.AsyncClient, url: str, data: bytes, fields: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def _post_file(url: str, data: bytes, fields: Dict[str, Any]) -> Dict[str, Any]:
|
||||
files = {"file": ("image", data, "application/octet-stream")}
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
r = await client.post(url, data={k: str(v) for k, v in fields.items()}, files=files)
|
||||
r = await get_http_client().post(url, data={k: str(v) for k, v in fields.items()}, files=files)
|
||||
except httpx.RequestError as e:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}")
|
||||
elapsed = (time.perf_counter() - t0) * 1000
|
||||
logger.debug("POST(file) %s status=%s elapsed_ms=%.1f", url, r.status_code, elapsed)
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:1000]}")
|
||||
try:
|
||||
@@ -87,11 +160,14 @@ async def _post_file(client: httpx.AsyncClient, url: str, data: bytes, fields: D
|
||||
raise HTTPException(status_code=502, detail=f"Upstream returned non-JSON at {url}: {r.status_code} {r.text[:1000]}")
|
||||
|
||||
|
||||
async def _get_json(client: httpx.AsyncClient, url: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
async def _get_json(url: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
r = await client.get(url, params=params)
|
||||
r = await get_http_client().get(url, params=params)
|
||||
except httpx.RequestError as e:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}")
|
||||
elapsed = (time.perf_counter() - t0) * 1000
|
||||
logger.debug("GET %s status=%s elapsed_ms=%.1f", url, r.status_code, elapsed)
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:1000]}")
|
||||
try:
|
||||
@@ -102,14 +178,25 @@ async def _get_json(client: httpx.AsyncClient, url: str, params: Optional[Dict[s
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
clip_h, blip_h, yolo_h, qdrant_h = await asyncio.gather(
|
||||
_get_health(client, CLIP_URL),
|
||||
_get_health(client, BLIP_URL),
|
||||
_get_health(client, YOLO_URL),
|
||||
_get_health(client, QDRANT_SVC_URL),
|
||||
)
|
||||
return {"status": "ok", "services": {"clip": clip_h, "blip": blip_h, "yolo": yolo_h, "qdrant": qdrant_h}}
|
||||
health_checks = [
|
||||
_get_health(CLIP_URL),
|
||||
_get_health(BLIP_URL),
|
||||
_get_health(YOLO_URL),
|
||||
_get_health(QDRANT_SVC_URL),
|
||||
]
|
||||
if MATURITY_ENABLED:
|
||||
health_checks.append(_get_health(MATURITY_URL))
|
||||
|
||||
results = await asyncio.gather(*health_checks)
|
||||
services: Dict[str, Any] = {
|
||||
"clip": results[0],
|
||||
"blip": results[1],
|
||||
"yolo": results[2],
|
||||
"qdrant": results[3],
|
||||
}
|
||||
if MATURITY_ENABLED:
|
||||
services["maturity"] = results[4]
|
||||
return {"status": "ok", "services": services}
|
||||
|
||||
|
||||
# ---- Individual analyze endpoints (URL) ----
|
||||
@@ -118,24 +205,21 @@ async def health():
|
||||
async def analyze_clip(req: ClipRequest):
|
||||
if not req.url:
|
||||
raise HTTPException(400, "url is required")
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{CLIP_URL}/analyze", req.model_dump())
|
||||
return await _post_json(f"{CLIP_URL}/analyze", req.model_dump())
|
||||
|
||||
|
||||
@app.post("/analyze/blip")
|
||||
async def analyze_blip(req: BlipRequest):
|
||||
if not req.url:
|
||||
raise HTTPException(400, "url is required")
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{BLIP_URL}/caption", req.model_dump())
|
||||
return await _post_json(f"{BLIP_URL}/caption", req.model_dump())
|
||||
|
||||
|
||||
@app.post("/analyze/yolo")
|
||||
async def analyze_yolo(req: YoloRequest):
|
||||
if not req.url:
|
||||
raise HTTPException(400, "url is required")
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{YOLO_URL}/detect", req.model_dump())
|
||||
return await _post_json(f"{YOLO_URL}/detect", req.model_dump())
|
||||
|
||||
|
||||
# ---- Individual analyze endpoints (file upload) ----
|
||||
@@ -151,8 +235,7 @@ async def analyze_clip_file(
|
||||
fields: Dict[str, Any] = {"limit": int(limit)}
|
||||
if threshold is not None:
|
||||
fields["threshold"] = float(threshold)
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_file(client, f"{CLIP_URL}/analyze/file", data, fields)
|
||||
return await _post_file(f"{CLIP_URL}/analyze/file", data, fields)
|
||||
|
||||
|
||||
@app.post("/analyze/blip/file")
|
||||
@@ -163,8 +246,7 @@ async def analyze_blip_file(
|
||||
):
|
||||
data = await file.read()
|
||||
fields = {"variants": int(variants), "max_length": int(max_length)}
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_file(client, f"{BLIP_URL}/caption/file", data, fields)
|
||||
return await _post_file(f"{BLIP_URL}/caption/file", data, fields)
|
||||
|
||||
|
||||
@app.post("/analyze/yolo/file")
|
||||
@@ -174,8 +256,7 @@ async def analyze_yolo_file(
|
||||
):
|
||||
data = await file.read()
|
||||
fields = {"conf": float(conf)}
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_file(client, f"{YOLO_URL}/detect/file", data, fields)
|
||||
return await _post_file(f"{YOLO_URL}/detect/file", data, fields)
|
||||
|
||||
|
||||
@app.post("/analyze/all")
|
||||
@@ -188,13 +269,11 @@ async def analyze_all(payload: Dict[str, Any]):
|
||||
blip_req = {"url": url, "variants": int(payload.get("variants", 3)), "max_length": int(payload.get("max_length", 60))}
|
||||
yolo_req = {"url": url, "conf": float(payload.get("conf", 0.25))}
|
||||
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
clip_task = _post_json(client, f"{CLIP_URL}/analyze", clip_req)
|
||||
blip_task = _post_json(client, f"{BLIP_URL}/caption", blip_req)
|
||||
yolo_task = _post_json(client, f"{YOLO_URL}/detect", yolo_req)
|
||||
|
||||
clip_res, blip_res, yolo_res = await asyncio.gather(clip_task, blip_task, yolo_task)
|
||||
|
||||
clip_res, blip_res, yolo_res = await asyncio.gather(
|
||||
_post_json(f"{CLIP_URL}/analyze", clip_req),
|
||||
_post_json(f"{BLIP_URL}/caption", blip_req),
|
||||
_post_json(f"{YOLO_URL}/detect", yolo_req),
|
||||
)
|
||||
return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res}
|
||||
|
||||
|
||||
@@ -202,8 +281,7 @@ async def analyze_all(payload: Dict[str, Any]):
|
||||
|
||||
@app.post("/vectors/upsert")
|
||||
async def vectors_upsert(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/upsert", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/upsert", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/upsert/file")
|
||||
@@ -221,20 +299,17 @@ async def vectors_upsert_file(
|
||||
fields["collection"] = collection
|
||||
if metadata_json is not None:
|
||||
fields["metadata_json"] = metadata_json
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_file(client, f"{QDRANT_SVC_URL}/upsert/file", data, fields)
|
||||
return await _post_file(f"{QDRANT_SVC_URL}/upsert/file", data, fields)
|
||||
|
||||
|
||||
@app.post("/vectors/upsert/vector")
|
||||
async def vectors_upsert_vector(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/upsert/vector", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/upsert/vector", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/search")
|
||||
async def vectors_search(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/search", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/search", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/search/file")
|
||||
@@ -258,51 +333,50 @@ async def vectors_search_file(
|
||||
fields["hnsw_ef"] = int(hnsw_ef)
|
||||
if filter_metadata_json is not None:
|
||||
fields["filter_metadata_json"] = filter_metadata_json
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_file(client, f"{QDRANT_SVC_URL}/search/file", data, fields)
|
||||
return await _post_file(f"{QDRANT_SVC_URL}/search/file", data, fields)
|
||||
|
||||
|
||||
@app.post("/vectors/search/vector")
|
||||
async def vectors_search_vector(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/search/vector", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/search/vector", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/delete")
|
||||
async def vectors_delete(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/delete", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/delete", payload)
|
||||
|
||||
|
||||
@app.get("/vectors/collections")
|
||||
async def vectors_collections():
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _get_json(client, f"{QDRANT_SVC_URL}/collections")
|
||||
return await _get_json(f"{QDRANT_SVC_URL}/collections")
|
||||
|
||||
|
||||
@app.post("/vectors/collections")
|
||||
async def vectors_create_collection(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/collections", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/collections", payload)
|
||||
|
||||
|
||||
@app.get("/vectors/collections/{name}")
|
||||
async def vectors_collection_info(name: str):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _get_json(client, f"{QDRANT_SVC_URL}/collections/{name}")
|
||||
return await _get_json(f"{QDRANT_SVC_URL}/collections/{name}")
|
||||
|
||||
|
||||
@app.get("/vectors/inspect")
|
||||
async def vectors_inspect():
|
||||
"""Full diagnostic summary for all Qdrant collections (HNSW, optimizer, payload indexes, RAM estimate)."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _get_json(client, f"{QDRANT_SVC_URL}/inspect")
|
||||
t0 = time.perf_counter()
|
||||
logger.info("vectors_inspect: start")
|
||||
result = await _get_json(f"{QDRANT_SVC_URL}/inspect")
|
||||
logger.info("vectors_inspect: done elapsed_ms=%.1f", (time.perf_counter() - t0) * 1000)
|
||||
return result
|
||||
|
||||
|
||||
@app.delete("/vectors/collections/{name}")
|
||||
async def vectors_delete_collection(name: str):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
r = await client.delete(f"{QDRANT_SVC_URL}/collections/{name}")
|
||||
try:
|
||||
r = await get_http_client().delete(f"{QDRANT_SVC_URL}/collections/{name}")
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream request failed: {exc}")
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}")
|
||||
return r.json()
|
||||
@@ -310,20 +384,18 @@ async def vectors_delete_collection(name: str):
|
||||
|
||||
@app.get("/vectors/points/{point_id}")
|
||||
async def vectors_get_point(point_id: str, collection: Optional[str] = None):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
params = {}
|
||||
if collection:
|
||||
params["collection"] = collection
|
||||
return await _get_json(client, f"{QDRANT_SVC_URL}/points/{point_id}", params=params)
|
||||
return await _get_json(f"{QDRANT_SVC_URL}/points/{point_id}", params=params)
|
||||
|
||||
|
||||
@app.get("/vectors/points/by-original-id/{original_id}")
|
||||
async def vectors_get_point_by_original_id(original_id: str, collection: Optional[str] = None):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
params = {}
|
||||
if collection:
|
||||
params["collection"] = collection
|
||||
return await _get_json(client, f"{QDRANT_SVC_URL}/points/by-original-id/{original_id}", params=params)
|
||||
return await _get_json(f"{QDRANT_SVC_URL}/points/by-original-id/{original_id}", params=params)
|
||||
|
||||
|
||||
# ---- File-based universal analyze ----
|
||||
@@ -337,31 +409,61 @@ async def analyze_all_file(
|
||||
max_length: int = Form(60),
|
||||
):
|
||||
data = await file.read()
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
clip_task = _post_file(client, f"{CLIP_URL}/analyze/file", data, {"limit": limit})
|
||||
blip_task = _post_file(client, f"{BLIP_URL}/caption/file", data, {"variants": variants, "max_length": max_length})
|
||||
yolo_task = _post_file(client, f"{YOLO_URL}/detect/file", data, {"conf": conf})
|
||||
|
||||
clip_res, blip_res, yolo_res = await asyncio.gather(clip_task, blip_task, yolo_task)
|
||||
|
||||
clip_res, blip_res, yolo_res = await asyncio.gather(
|
||||
_post_file(f"{CLIP_URL}/analyze/file", data, {"limit": limit}),
|
||||
_post_file(f"{BLIP_URL}/caption/file", data, {"variants": variants, "max_length": max_length}),
|
||||
_post_file(f"{YOLO_URL}/detect/file", data, {"conf": conf}),
|
||||
)
|
||||
return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res}
|
||||
|
||||
|
||||
# ---- Maturity / NSFW analysis endpoints ----
|
||||
|
||||
|
||||
def _assert_maturity_enabled() -> None:
|
||||
if not MATURITY_ENABLED:
|
||||
raise HTTPException(status_code=503, detail="Maturity service is disabled")
|
||||
|
||||
|
||||
@app.post("/analyze/maturity")
|
||||
async def analyze_maturity(req: MaturityRequest):
|
||||
"""Analyze an image URL for maturity / NSFW content.
|
||||
|
||||
Returns a normalized maturity signal including maturity_label (safe/mature),
|
||||
confidence, score, optional sublabels, and an action_hint for Nova moderation.
|
||||
"""
|
||||
_assert_maturity_enabled()
|
||||
if not req.url:
|
||||
raise HTTPException(status_code=400, detail="url is required")
|
||||
logger.info("analyze_maturity: url=%s", req.url)
|
||||
return await _post_json(f"{MATURITY_URL}/analyze", req.model_dump())
|
||||
|
||||
|
||||
@app.post("/analyze/maturity/file")
|
||||
async def analyze_maturity_file(file: UploadFile = File(...)):
|
||||
"""Analyze an uploaded image file for maturity / NSFW content.
|
||||
|
||||
Returns the same normalized maturity signal as /analyze/maturity.
|
||||
"""
|
||||
_assert_maturity_enabled()
|
||||
data = await file.read()
|
||||
logger.info("analyze_maturity_file: filename=%s size=%d", file.filename, len(data))
|
||||
return await _post_file(f"{MATURITY_URL}/analyze/file", data, {})
|
||||
|
||||
|
||||
# ---- Card renderer endpoints ----
|
||||
|
||||
@app.get("/cards/templates")
|
||||
async def cards_templates():
|
||||
"""List available card templates."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _get_json(client, f"{CARD_RENDERER_URL}/templates")
|
||||
return await _get_json(f"{CARD_RENDERER_URL}/templates")
|
||||
|
||||
|
||||
@app.post("/cards/render")
|
||||
async def cards_render(payload: Dict[str, Any]):
|
||||
"""Render a Nova card from a remote image URL. Returns binary image bytes."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
try:
|
||||
resp = await client.post(f"{CARD_RENDERER_URL}/render", json=payload)
|
||||
resp = await get_http_client().post(f"{CARD_RENDERER_URL}/render", json=payload)
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"card-renderer unreachable: {exc}")
|
||||
if resp.status_code >= 400:
|
||||
@@ -409,9 +511,8 @@ async def cards_render_file(
|
||||
fields["tags_json"] = tags_json
|
||||
|
||||
upload_files = {"file": (file.filename or "image", data, file.content_type or "application/octet-stream")}
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
resp = await get_http_client().post(
|
||||
f"{CARD_RENDERER_URL}/render/file",
|
||||
data={k: str(v) for k, v in fields.items()},
|
||||
files=upload_files,
|
||||
@@ -429,8 +530,7 @@ async def cards_render_file(
|
||||
@app.post("/cards/render/meta")
|
||||
async def cards_render_meta(payload: Dict[str, Any]):
|
||||
"""Return crop and layout metadata for a card render (no image produced)."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{CARD_RENDERER_URL}/render/meta", payload)
|
||||
return await _post_json(f"{CARD_RENDERER_URL}/render/meta", payload)
|
||||
|
||||
|
||||
# ---- Qdrant administration endpoints (index management + collection config) ----
|
||||
@@ -438,26 +538,22 @@ async def cards_render_meta(payload: Dict[str, Any]):
|
||||
@app.get("/vectors/collections/{name}/indexes")
|
||||
async def vectors_collection_indexes(name: str):
|
||||
"""List payload indexes for a collection."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _get_json(client, f"{QDRANT_SVC_URL}/collections/{name}/indexes")
|
||||
return await _get_json(f"{QDRANT_SVC_URL}/collections/{name}/indexes")
|
||||
|
||||
|
||||
@app.post("/vectors/collections/{name}/indexes")
|
||||
async def vectors_create_payload_index(name: str, payload: Dict[str, Any]):
|
||||
"""Create a payload index on a field in a collection."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/collections/{name}/indexes", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/collections/{name}/indexes", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/collections/{name}/ensure-indexes")
|
||||
async def vectors_ensure_indexes(name: str, payload: Dict[str, Any]):
|
||||
"""Idempotently ensure payload indexes exist for a list of fields."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/collections/{name}/ensure-indexes", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/collections/{name}/ensure-indexes", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/collections/{name}/configure")
|
||||
async def vectors_configure_collection(name: str, payload: Dict[str, Any]):
|
||||
"""Update HNSW and optimizer configuration for a collection."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/collections/{name}/configure", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/collections/{name}/configure", payload)
|
||||
|
||||
17
maturity/Dockerfile
Normal file
17
maturity/Dockerfile
Normal file
@@ -0,0 +1,17 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY maturity/requirements.txt /app/requirements.txt
|
||||
RUN pip install --no-cache-dir -r /app/requirements.txt
|
||||
|
||||
COPY maturity /app
|
||||
COPY common /app/common
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
221
maturity/main.py
Normal file
221
maturity/main.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Skinbase Maturity Analysis Service.
|
||||
|
||||
Uses a dedicated NSFW/maturity ViT classifier (Falconsai/nsfw_image_detection)
|
||||
to produce a structured, normalized maturity signal consumable by Nova moderation
|
||||
workflows.
|
||||
|
||||
Endpoints
|
||||
---------
|
||||
GET /health — liveness + model info
|
||||
POST /analyze — URL-based maturity analysis
|
||||
POST /analyze/file — multipart file-upload maturity analysis
|
||||
|
||||
Response contract
|
||||
-----------------
|
||||
{
|
||||
"maturity_label": "safe" | "mature",
|
||||
"confidence": float, # confidence in the maturity_label decision
|
||||
"score": float, # raw NSFW probability from model (0.0 – 1.0)
|
||||
"labels": list[str], # sublabels when mature, e.g. ["nsfw"]
|
||||
"model": str, # model identifier
|
||||
"threshold_used": float, # threshold that produced the label
|
||||
"analysis_time_ms": float,
|
||||
"source": "maturity-service",
|
||||
"action_hint": "safe" | "review" | "flag_high",
|
||||
"advisory": str # short human-readable reason
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||||
|
||||
from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError
|
||||
|
||||
logger = logging.getLogger("maturity")
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration — all tunable via environment variables
|
||||
# ---------------------------------------------------------------------------
|
||||
MATURITY_MODEL = os.getenv("MATURITY_MODEL", "Falconsai/nsfw_image_detection")
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# Main threshold: score >= this → "mature" + "flag_high"
|
||||
THRESHOLD_MATURE = float(os.getenv("MATURITY_THRESHOLD_MATURE", "0.80"))
|
||||
# Review band: score >= this (but below MATURE) → "mature" + "review"
|
||||
THRESHOLD_REVIEW = float(os.getenv("MATURITY_THRESHOLD_REVIEW", "0.60"))
|
||||
|
||||
# Max image bytes — same default as the rest of the stack (50 MB)
|
||||
MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", str(50 * 1024 * 1024)))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model loading — done once at import time so Docker start captures it
|
||||
# ---------------------------------------------------------------------------
|
||||
logger.info("maturity service: loading model %s on %s", MATURITY_MODEL, DEVICE)
|
||||
_t_load = time.perf_counter()
|
||||
|
||||
_processor = AutoImageProcessor.from_pretrained(MATURITY_MODEL)
|
||||
_model = AutoModelForImageClassification.from_pretrained(MATURITY_MODEL).to(DEVICE).eval()
|
||||
|
||||
# Build a label→index map from the model config so we are not fragile to label
|
||||
# ordering changes.
|
||||
_ID2LABEL: dict[int, str] = _model.config.id2label # e.g. {0: "normal", 1: "nsfw"}
|
||||
_NSFW_IDX: int = next(
|
||||
(i for i, lbl in _ID2LABEL.items() if "nsfw" in lbl.lower() or "explicit" in lbl.lower()),
|
||||
1, # fallback: assume index 1 is the NSFW class
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"maturity service: model loaded elapsed_ms=%.1f device=%s id2label=%s nsfw_idx=%s",
|
||||
(time.perf_counter() - _t_load) * 1000,
|
||||
DEVICE,
|
||||
_ID2LABEL,
|
||||
_NSFW_IDX,
|
||||
)
|
||||
|
||||
app = FastAPI(title="Skinbase Maturity Service", version="1.0.0")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MaturityRequest(BaseModel):
|
||||
url: Optional[str] = Field(default=None, description="Public image URL to analyse")
|
||||
|
||||
|
||||
class MaturityResponse(BaseModel):
|
||||
maturity_label: str = Field(description='"safe" or "mature"')
|
||||
confidence: float = Field(description="Confidence in the maturity_label decision (0–1)")
|
||||
score: float = Field(description="Raw NSFW probability from the model (0–1)")
|
||||
labels: List[str] = Field(description="Sublabels when mature content is detected")
|
||||
model: str = Field(description="Model identifier / version")
|
||||
threshold_used: float = Field(description="Threshold applied to produce the label")
|
||||
analysis_time_ms: float
|
||||
source: str = "maturity-service"
|
||||
action_hint: str = Field(description='"safe", "review", or "flag_high"')
|
||||
advisory: str = Field(description="Short human-readable reason for the decision")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _run_inference(data: bytes) -> MaturityResponse:
|
||||
"""Run maturity inference on raw image bytes and return a structured response."""
|
||||
t0 = time.perf_counter()
|
||||
|
||||
try:
|
||||
img = bytes_to_pil(data)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=f"Cannot decode image: {exc}") from exc
|
||||
|
||||
inputs = _processor(images=img, return_tensors="pt").to(DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = _model(**inputs).logits
|
||||
|
||||
probs = torch.softmax(logits, dim=-1)[0]
|
||||
nsfw_score = float(probs[_NSFW_IDX])
|
||||
|
||||
elapsed_ms = (time.perf_counter() - t0) * 1000
|
||||
|
||||
# Derive label, action_hint, advisory, sublabels
|
||||
if nsfw_score >= THRESHOLD_MATURE:
|
||||
maturity_label = "mature"
|
||||
action_hint = "flag_high"
|
||||
advisory = "High-confidence mature content detected"
|
||||
labels = ["nsfw"]
|
||||
threshold_used = THRESHOLD_MATURE
|
||||
confidence = nsfw_score
|
||||
elif nsfw_score >= THRESHOLD_REVIEW:
|
||||
maturity_label = "mature"
|
||||
action_hint = "review"
|
||||
advisory = "Possible mature content — review recommended"
|
||||
labels = ["nsfw"]
|
||||
threshold_used = THRESHOLD_REVIEW
|
||||
confidence = nsfw_score
|
||||
else:
|
||||
maturity_label = "safe"
|
||||
action_hint = "safe"
|
||||
advisory = "Content appears safe"
|
||||
labels = []
|
||||
threshold_used = THRESHOLD_REVIEW
|
||||
confidence = 1.0 - nsfw_score # confidence in the "safe" verdict
|
||||
|
||||
logger.info(
|
||||
"maturity inference: maturity_label=%s action_hint=%s score=%.4f "
|
||||
"confidence=%.4f threshold_mature=%.2f threshold_review=%.2f elapsed_ms=%.1f",
|
||||
maturity_label,
|
||||
action_hint,
|
||||
nsfw_score,
|
||||
confidence,
|
||||
THRESHOLD_MATURE,
|
||||
THRESHOLD_REVIEW,
|
||||
elapsed_ms,
|
||||
)
|
||||
|
||||
return MaturityResponse(
|
||||
maturity_label=maturity_label,
|
||||
confidence=round(confidence, 4),
|
||||
score=round(nsfw_score, 4),
|
||||
labels=labels,
|
||||
model=MATURITY_MODEL,
|
||||
threshold_used=threshold_used,
|
||||
analysis_time_ms=round(elapsed_ms, 1),
|
||||
source="maturity-service",
|
||||
action_hint=action_hint,
|
||||
advisory=advisory,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"device": DEVICE,
|
||||
"model": MATURITY_MODEL,
|
||||
"threshold_mature": THRESHOLD_MATURE,
|
||||
"threshold_review": THRESHOLD_REVIEW,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/analyze", response_model=MaturityResponse)
|
||||
def analyze(req: MaturityRequest):
|
||||
"""URL-based maturity analysis."""
|
||||
if not req.url:
|
||||
raise HTTPException(status_code=400, detail="url is required")
|
||||
try:
|
||||
data = fetch_url_bytes(req.url, max_bytes=MAX_IMAGE_BYTES)
|
||||
except ImageLoadError as exc:
|
||||
logger.warning("maturity analyze: image load failed url=%s error=%s", req.url, exc)
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
return _run_inference(data)
|
||||
|
||||
|
||||
@app.post("/analyze/file", response_model=MaturityResponse)
|
||||
async def analyze_file(file: UploadFile = File(...)):
|
||||
"""Multipart file-upload maturity analysis."""
|
||||
data = await file.read()
|
||||
if len(data) > MAX_IMAGE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File exceeds maximum allowed size of {MAX_IMAGE_BYTES} bytes",
|
||||
)
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="Empty file upload")
|
||||
return _run_inference(data)
|
||||
8
maturity/requirements.txt
Normal file
8
maturity/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi==0.115.5
|
||||
uvicorn[standard]==0.30.6
|
||||
python-multipart==0.0.9
|
||||
requests==2.32.3
|
||||
pillow==10.4.0
|
||||
torch==2.4.1
|
||||
torchvision==0.19.1
|
||||
transformers==4.44.2
|
||||
Reference in New Issue
Block a user