vision/maturity/main.py
Gregor Klevze baf497b015 feat(maturity): add dedicated NSFW/maturity analysis service
- Add maturity/ service (FastAPI + Falconsai/nsfw_image_detection ViT classifier)
  - /analyze (URL) and /analyze/file (multipart upload) endpoints
  - Normalized response: maturity_label, confidence, score, labels,
    action_hint (safe/review/flag_high), advisory, threshold_used,
    analysis_time_ms, model, source
  - Configurable thresholds via MATURITY_THRESHOLD_MATURE / MATURITY_THRESHOLD_REVIEW
  - Reuses common/image_io for URL validation and file-size enforcement
  - Explicit 502/503 errors on failure — no silent safe fallback
  - Per-request structured logging (score, label, threshold path, elapsed ms)

- Update gateway/main.py
  - MATURITY_URL + MATURITY_ENABLED env vars
  - POST /analyze/maturity and POST /analyze/maturity/file endpoints
  - /health includes maturity service status
  - _assert_maturity_enabled() guard for clean 503 when disabled
  - All existing endpoints untouched (additive change)

- Update docker-compose.yml
  - Add maturity service with healthcheck (start_period: 90s)
  - Gateway environment: MATURITY_URL, MATURITY_ENABLED
  - Gateway depends_on: maturity (service_healthy)

- Update README.md and USAGE.md
  - Document maturity service, env vars, curl examples,
    full response schema table, action_hint logic, failure guidance
2026-04-11 17:29:26 +02:00
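
For orientation, a minimal client-side sketch of calling the new gateway route is shown below. The gateway address and port are assumptions; the endpoint path, error behavior, and response fields come from this commit.

import requests

GATEWAY = "http://localhost:8080"  # assumed address; adjust to your deployment

# POST /analyze/maturity proxies to the maturity service's URL-based endpoint.
resp = requests.post(
    f"{GATEWAY}/analyze/maturity",
    json={"url": "https://example.com/image.jpg"},
    timeout=30,
)
resp.raise_for_status()  # gateway surfaces explicit 502/503 on failure, no silent safe fallback
result = resp.json()
print(result["maturity_label"], result["action_hint"], result["score"])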

"""Skinbase Maturity Analysis Service.
Uses a dedicated NSFW/maturity ViT classifier (Falconsai/nsfw_image_detection)
to produce a structured, normalized maturity signal consumable by Nova moderation
workflows.
Endpoints
---------
GET /health — liveness + model info
POST /analyze — URL-based maturity analysis
POST /analyze/file — multipart file-upload maturity analysis
Response contract
-----------------
{
"maturity_label": "safe" | "mature",
"confidence": float, # confidence in the maturity_label decision
"score": float, # raw NSFW probability from model (0.0 1.0)
"labels": list[str], # sublabels when mature, e.g. ["nsfw"]
"model": str, # model identifier
"threshold_used": float, # threshold that produced the label
"analysis_time_ms": float,
"source": "maturity-service",
"action_hint": "safe" | "review" | "flag_high",
"advisory": str # short human-readable reason
}
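
Example exchange (illustrative values; host and port depend on deployment):

    POST /analyze
    {"url": "https://example.com/image.jpg"}

    → {"maturity_label": "safe", "action_hint": "safe", "score": 0.03, ...}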
"""
from __future__ import annotations
import logging
import os
import time
from typing import List, Optional
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File
from pydantic import BaseModel, Field
from transformers import AutoImageProcessor, AutoModelForImageClassification
from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError
logger = logging.getLogger("maturity")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
# ---------------------------------------------------------------------------
# Configuration — all tunable via environment variables
# ---------------------------------------------------------------------------
MATURITY_MODEL = os.getenv("MATURITY_MODEL", "Falconsai/nsfw_image_detection")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Main threshold: score >= this → "mature" + "flag_high"
THRESHOLD_MATURE = float(os.getenv("MATURITY_THRESHOLD_MATURE", "0.80"))
# Review band: score >= this (but below MATURE) → "mature" + "review"
THRESHOLD_REVIEW = float(os.getenv("MATURITY_THRESHOLD_REVIEW", "0.60"))
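# NOTE: the review band only exists while THRESHOLD_REVIEW < THRESHOLD_MATURE;
# with the defaults, scores in [0.60, 0.80) map to "mature" + "review".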
# Max image bytes — same default as the rest of the stack (50 MB)
MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", str(50 * 1024 * 1024)))
# ---------------------------------------------------------------------------
# Model loading — done once at import time so Docker start captures it
# ---------------------------------------------------------------------------
logger.info("maturity service: loading model %s on %s", MATURITY_MODEL, DEVICE)
_t_load = time.perf_counter()
_processor = AutoImageProcessor.from_pretrained(MATURITY_MODEL)
_model = AutoModelForImageClassification.from_pretrained(MATURITY_MODEL).to(DEVICE).eval()
# Build a label→index map from the model config so we are not fragile to label
# ordering changes.
_ID2LABEL: dict[int, str] = _model.config.id2label # e.g. {0: "normal", 1: "nsfw"}
_NSFW_IDX: int = next(
    (i for i, lbl in _ID2LABEL.items() if "nsfw" in lbl.lower() or "explicit" in lbl.lower()),
    1,  # fallback: assume index 1 is the NSFW class
)
logger.info(
"maturity service: model loaded elapsed_ms=%.1f device=%s id2label=%s nsfw_idx=%s",
(time.perf_counter() - _t_load) * 1000,
DEVICE,
_ID2LABEL,
_NSFW_IDX,
)
app = FastAPI(title="Skinbase Maturity Service", version="1.0.0")
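# Local run example (sketch; the port is deployment-specific):
#   uvicorn main:app --host 0.0.0.0 --port 8000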
# ---------------------------------------------------------------------------
# Schemas
# ---------------------------------------------------------------------------
class MaturityRequest(BaseModel):
    url: Optional[str] = Field(default=None, description="Public image URL to analyse")


class MaturityResponse(BaseModel):
    maturity_label: str = Field(description='"safe" or "mature"')
    confidence: float = Field(description="Confidence in the maturity_label decision (0–1)")
    score: float = Field(description="Raw NSFW probability from the model (0–1)")
    labels: List[str] = Field(description="Sublabels when mature content is detected")
    model: str = Field(description="Model identifier / version")
    threshold_used: float = Field(description="Threshold applied to produce the label")
    analysis_time_ms: float
    source: str = "maturity-service"
    action_hint: str = Field(description='"safe", "review", or "flag_high"')
    advisory: str = Field(description="Short human-readable reason for the decision")
# ---------------------------------------------------------------------------
# Inference helper
# ---------------------------------------------------------------------------
def _run_inference(data: bytes) -> MaturityResponse:
    """Run maturity inference on raw image bytes and return a structured response."""
    t0 = time.perf_counter()
    try:
        img = bytes_to_pil(data)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Cannot decode image: {exc}") from exc
    inputs = _processor(images=img, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        logits = _model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0]
    nsfw_score = float(probs[_NSFW_IDX])
    elapsed_ms = (time.perf_counter() - t0) * 1000
    # Derive label, action_hint, advisory, sublabels
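    # Worked example with the default thresholds (mature=0.80, review=0.60):
    #   score 0.85 -> "mature" / "flag_high"
    #   score 0.70 -> "mature" / "review"
    #   score 0.30 -> "safe"   (confidence reported as 0.70)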
    if nsfw_score >= THRESHOLD_MATURE:
        maturity_label = "mature"
        action_hint = "flag_high"
        advisory = "High-confidence mature content detected"
        labels = ["nsfw"]
        threshold_used = THRESHOLD_MATURE
        confidence = nsfw_score
    elif nsfw_score >= THRESHOLD_REVIEW:
        maturity_label = "mature"
        action_hint = "review"
        advisory = "Possible mature content — review recommended"
        labels = ["nsfw"]
        threshold_used = THRESHOLD_REVIEW
        confidence = nsfw_score
    else:
        maturity_label = "safe"
        action_hint = "safe"
        advisory = "Content appears safe"
        labels = []
        threshold_used = THRESHOLD_REVIEW
        confidence = 1.0 - nsfw_score  # confidence in the "safe" verdict
    logger.info(
        "maturity inference: maturity_label=%s action_hint=%s score=%.4f "
        "confidence=%.4f threshold_mature=%.2f threshold_review=%.2f elapsed_ms=%.1f",
        maturity_label,
        action_hint,
        nsfw_score,
        confidence,
        THRESHOLD_MATURE,
        THRESHOLD_REVIEW,
        elapsed_ms,
    )
    return MaturityResponse(
        maturity_label=maturity_label,
        confidence=round(confidence, 4),
        score=round(nsfw_score, 4),
        labels=labels,
        model=MATURITY_MODEL,
        threshold_used=threshold_used,
        analysis_time_ms=round(elapsed_ms, 1),
        source="maturity-service",
        action_hint=action_hint,
        advisory=advisory,
    )

# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
def health():
    return {
        "status": "ok",
        "device": DEVICE,
        "model": MATURITY_MODEL,
        "threshold_mature": THRESHOLD_MATURE,
        "threshold_review": THRESHOLD_REVIEW,
    }

@app.post("/analyze", response_model=MaturityResponse)
def analyze(req: MaturityRequest):
"""URL-based maturity analysis."""
    if not req.url:
        raise HTTPException(status_code=400, detail="url is required")
    try:
        data = fetch_url_bytes(req.url, max_bytes=MAX_IMAGE_BYTES)
    except ImageLoadError as exc:
        logger.warning("maturity analyze: image load failed url=%s error=%s", req.url, exc)
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return _run_inference(data)

@app.post("/analyze/file", response_model=MaturityResponse)
async def analyze_file(file: UploadFile = File(...)):
"""Multipart file-upload maturity analysis."""
    data = await file.read()
    if len(data) > MAX_IMAGE_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"File exceeds maximum allowed size of {MAX_IMAGE_BYTES} bytes",
        )
    if not data:
        raise HTTPException(status_code=400, detail="Empty file upload")
    return _run_inference(data)