- Add maturity/ service (FastAPI + Falconsai/nsfw_image_detection ViT classifier)
- /analyze (URL) and /analyze/file (multipart upload) endpoints
- Normalized response: maturity_label, confidence, score, labels,
action_hint (safe/review/flag_high), advisory, threshold_used,
analysis_time_ms, model, source
- Configurable thresholds via MATURITY_THRESHOLD_MATURE / MATURITY_THRESHOLD_REVIEW
- Reuses common/image_io for URL validation and file-size enforcement
- Explicit 502/503 errors on failure — no silent safe fallback
- Per-request structured logging (score, label, threshold path, elapsed ms)
- Update gateway/main.py (a proxy sketch follows this list)
- MATURITY_URL + MATURITY_ENABLED env vars
- POST /analyze/maturity and POST /analyze/maturity/file endpoints
- /health includes maturity service status
- _assert_maturity_enabled() guard for clean 503 when disabled
- All existing endpoints untouched (additive change)
- Update docker-compose.yml (a compose sketch follows this list)
- Add maturity service with healthcheck (start_period: 90s)
- Gateway environment: MATURITY_URL, MATURITY_ENABLED
- Gateway depends_on: maturity (service_healthy)
- Update README.md and USAGE.md
- Document maturity service, env vars, curl examples (illustrative calls below),
full response schema table, action_hint logic, failure guidance
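For context, a minimal sketch of what the gateway/main.py additions above could look like; the httpx proxying pattern, handler names, and error details are illustrative assumptions, and the /health aggregation is omitted:

```python
# Hypothetical sketch only; route shapes and httpx usage are assumptions,
# not the actual gateway/main.py.
import os

import httpx
from fastapi import FastAPI, File, HTTPException, UploadFile

app = FastAPI()

MATURITY_URL = os.getenv("MATURITY_URL", "http://maturity:8000")
MATURITY_ENABLED = os.getenv("MATURITY_ENABLED", "true").lower() == "true"


def _assert_maturity_enabled() -> None:
    """Return a clean 503 when the maturity service is disabled."""
    if not MATURITY_ENABLED:
        raise HTTPException(status_code=503, detail="Maturity analysis is disabled")


@app.post("/analyze/maturity")
async def analyze_maturity(payload: dict):
    _assert_maturity_enabled()
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(f"{MATURITY_URL}/analyze", json=payload)
    if resp.status_code != 200:
        # Surface upstream failures explicitly (no silent safe fallback).
        raise HTTPException(status_code=502, detail=resp.text)
    return resp.json()


@app.post("/analyze/maturity/file")
async def analyze_maturity_file(file: UploadFile = File(...)):
    _assert_maturity_enabled()
    data = await file.read()
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            f"{MATURITY_URL}/analyze/file",
            files={"file": (file.filename, data, file.content_type)},
        )
    if resp.status_code != 200:
        raise HTTPException(status_code=502, detail=resp.text)
    return resp.json()
```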
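Likewise, a hedged sketch of the docker-compose.yml wiring; build paths, ports, and the healthcheck command are assumptions, while start_period, the gateway env vars, and the service_healthy condition follow the list above:

```yaml
# Illustrative excerpt only; build paths and the healthcheck command are assumptions.
services:
  maturity:
    build: ./maturity
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 5
      start_period: 90s

  gateway:
    build: ./gateway
    environment:
      MATURITY_URL: http://maturity:8000
      MATURITY_ENABLED: "true"
    depends_on:
      maturity:
        condition: service_healthy
```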
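And illustrative curl calls of the kind the docs describe; the gateway address and file names are placeholders:

```bash
# Hypothetical calls; localhost:8080 and the image URL/path are placeholders.
curl -X POST http://localhost:8080/analyze/maturity \
  -H "Content-Type: application/json" \
  -d '{"url": "https://example.com/skin.png"}'

curl -X POST http://localhost:8080/analyze/maturity/file \
  -F "file=@./skin.png"
```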
"""Skinbase Maturity Analysis Service.
|
||
|
||
Uses a dedicated NSFW/maturity ViT classifier (Falconsai/nsfw_image_detection)
|
||
to produce a structured, normalized maturity signal consumable by Nova moderation
|
||
workflows.
|
||
|
||
Endpoints
|
||
---------
|
||
GET /health — liveness + model info
|
||
POST /analyze — URL-based maturity analysis
|
||
POST /analyze/file — multipart file-upload maturity analysis
|
||
|
||
Response contract
|
||
-----------------
|
||
{
|
||
"maturity_label": "safe" | "mature",
|
||
"confidence": float, # confidence in the maturity_label decision
|
||
"score": float, # raw NSFW probability from model (0.0 – 1.0)
|
||
"labels": list[str], # sublabels when mature, e.g. ["nsfw"]
|
||
"model": str, # model identifier
|
||
"threshold_used": float, # threshold that produced the label
|
||
"analysis_time_ms": float,
|
||
"source": "maturity-service",
|
||
"action_hint": "safe" | "review" | "flag_high",
|
||
"advisory": str # short human-readable reason
|
||
}
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import os
|
||
import time
|
||
from typing import List, Optional
|
||
|
||
import torch
|
||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||
from pydantic import BaseModel, Field
|
||
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||
|
||
from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError
|
||
|
||
logger = logging.getLogger("maturity")
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||
)
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Configuration — all tunable via environment variables
|
||
# ---------------------------------------------------------------------------
|
||
MATURITY_MODEL = os.getenv("MATURITY_MODEL", "Falconsai/nsfw_image_detection")
|
||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||
|
||
# Main threshold: score >= this → "mature" + "flag_high"
|
||
THRESHOLD_MATURE = float(os.getenv("MATURITY_THRESHOLD_MATURE", "0.80"))
|
||
# Review band: score >= this (but below MATURE) → "mature" + "review"
|
||
THRESHOLD_REVIEW = float(os.getenv("MATURITY_THRESHOLD_REVIEW", "0.60"))
|
||
|
||
# Max image bytes — same default as the rest of the stack (50 MB)
|
||
MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", str(50 * 1024 * 1024)))
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Model loading — done once at import time so Docker start captures it
|
||
# ---------------------------------------------------------------------------
|
||
logger.info("maturity service: loading model %s on %s", MATURITY_MODEL, DEVICE)
|
||
_t_load = time.perf_counter()
|
||
|
||
_processor = AutoImageProcessor.from_pretrained(MATURITY_MODEL)
|
||
_model = AutoModelForImageClassification.from_pretrained(MATURITY_MODEL).to(DEVICE).eval()
|
||
|
||
# Build a label→index map from the model config so we are not fragile to label
|
||
# ordering changes.
|
||
_ID2LABEL: dict[int, str] = _model.config.id2label # e.g. {0: "normal", 1: "nsfw"}
|
||
_NSFW_IDX: int = next(
|
||
(i for i, lbl in _ID2LABEL.items() if "nsfw" in lbl.lower() or "explicit" in lbl.lower()),
|
||
1, # fallback: assume index 1 is the NSFW class
|
||
)
|
||
|
||
logger.info(
|
||
"maturity service: model loaded elapsed_ms=%.1f device=%s id2label=%s nsfw_idx=%s",
|
||
(time.perf_counter() - _t_load) * 1000,
|
||
DEVICE,
|
||
_ID2LABEL,
|
||
_NSFW_IDX,
|
||
)
|
||
|
||
app = FastAPI(title="Skinbase Maturity Service", version="1.0.0")
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Schemas
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class MaturityRequest(BaseModel):
|
||
url: Optional[str] = Field(default=None, description="Public image URL to analyse")
|
||
|
||
|
||
class MaturityResponse(BaseModel):
|
||
maturity_label: str = Field(description='"safe" or "mature"')
|
||
confidence: float = Field(description="Confidence in the maturity_label decision (0–1)")
|
||
score: float = Field(description="Raw NSFW probability from the model (0–1)")
|
||
labels: List[str] = Field(description="Sublabels when mature content is detected")
|
||
model: str = Field(description="Model identifier / version")
|
||
threshold_used: float = Field(description="Threshold applied to produce the label")
|
||
analysis_time_ms: float
|
||
source: str = "maturity-service"
|
||
action_hint: str = Field(description='"safe", "review", or "flag_high"')
|
||
advisory: str = Field(description="Short human-readable reason for the decision")
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Inference helper
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _run_inference(data: bytes) -> MaturityResponse:
|
||
"""Run maturity inference on raw image bytes and return a structured response."""
|
||
t0 = time.perf_counter()
|
||
|
||
try:
|
||
img = bytes_to_pil(data)
|
||
except Exception as exc:
|
||
raise HTTPException(status_code=400, detail=f"Cannot decode image: {exc}") from exc
|
||
|
||
inputs = _processor(images=img, return_tensors="pt").to(DEVICE)
|
||
|
||
with torch.no_grad():
|
||
logits = _model(**inputs).logits
|
||
|
||
probs = torch.softmax(logits, dim=-1)[0]
|
||
nsfw_score = float(probs[_NSFW_IDX])
|
||
|
||
elapsed_ms = (time.perf_counter() - t0) * 1000
|
||
|
||
# Derive label, action_hint, advisory, sublabels
|
||
if nsfw_score >= THRESHOLD_MATURE:
|
||
maturity_label = "mature"
|
||
action_hint = "flag_high"
|
||
advisory = "High-confidence mature content detected"
|
||
labels = ["nsfw"]
|
||
threshold_used = THRESHOLD_MATURE
|
||
confidence = nsfw_score
|
||
elif nsfw_score >= THRESHOLD_REVIEW:
|
||
maturity_label = "mature"
|
||
action_hint = "review"
|
||
advisory = "Possible mature content — review recommended"
|
||
labels = ["nsfw"]
|
||
threshold_used = THRESHOLD_REVIEW
|
||
confidence = nsfw_score
|
||
else:
|
||
maturity_label = "safe"
|
||
action_hint = "safe"
|
||
advisory = "Content appears safe"
|
||
labels = []
|
||
threshold_used = THRESHOLD_REVIEW
|
||
confidence = 1.0 - nsfw_score # confidence in the "safe" verdict
|
||
|
||
logger.info(
|
||
"maturity inference: maturity_label=%s action_hint=%s score=%.4f "
|
||
"confidence=%.4f threshold_mature=%.2f threshold_review=%.2f elapsed_ms=%.1f",
|
||
maturity_label,
|
||
action_hint,
|
||
nsfw_score,
|
||
confidence,
|
||
THRESHOLD_MATURE,
|
||
THRESHOLD_REVIEW,
|
||
elapsed_ms,
|
||
)
|
||
|
||
return MaturityResponse(
|
||
maturity_label=maturity_label,
|
||
confidence=round(confidence, 4),
|
||
score=round(nsfw_score, 4),
|
||
labels=labels,
|
||
model=MATURITY_MODEL,
|
||
threshold_used=threshold_used,
|
||
analysis_time_ms=round(elapsed_ms, 1),
|
||
source="maturity-service",
|
||
action_hint=action_hint,
|
||
advisory=advisory,
|
||
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@app.get("/health")
|
||
def health():
|
||
return {
|
||
"status": "ok",
|
||
"device": DEVICE,
|
||
"model": MATURITY_MODEL,
|
||
"threshold_mature": THRESHOLD_MATURE,
|
||
"threshold_review": THRESHOLD_REVIEW,
|
||
}
|
||
|
||
|
||
@app.post("/analyze", response_model=MaturityResponse)
|
||
def analyze(req: MaturityRequest):
|
||
"""URL-based maturity analysis."""
|
||
if not req.url:
|
||
raise HTTPException(status_code=400, detail="url is required")
|
||
try:
|
||
data = fetch_url_bytes(req.url, max_bytes=MAX_IMAGE_BYTES)
|
||
except ImageLoadError as exc:
|
||
logger.warning("maturity analyze: image load failed url=%s error=%s", req.url, exc)
|
||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||
return _run_inference(data)
|
||
|
||
|
||
@app.post("/analyze/file", response_model=MaturityResponse)
|
||
async def analyze_file(file: UploadFile = File(...)):
|
||
"""Multipart file-upload maturity analysis."""
|
||
data = await file.read()
|
||
if len(data) > MAX_IMAGE_BYTES:
|
||
raise HTTPException(
|
||
status_code=413,
|
||
detail=f"File exceeds maximum allowed size of {MAX_IMAGE_BYTES} bytes",
|
||
)
|
||
if not data:
|
||
raise HTTPException(status_code=400, detail="Empty file upload")
|
||
return _run_inference(data)
|