"""Skinbase Maturity Analysis Service. Uses a dedicated NSFW/maturity ViT classifier (Falconsai/nsfw_image_detection) to produce a structured, normalized maturity signal consumable by Nova moderation workflows. Endpoints --------- GET /health — liveness + model info POST /analyze — URL-based maturity analysis POST /analyze/file — multipart file-upload maturity analysis Response contract ----------------- { "maturity_label": "safe" | "mature", "confidence": float, # confidence in the maturity_label decision "score": float, # raw NSFW probability from model (0.0 – 1.0) "labels": list[str], # sublabels when mature, e.g. ["nsfw"] "model": str, # model identifier "threshold_used": float, # threshold that produced the label "analysis_time_ms": float, "source": "maturity-service", "action_hint": "safe" | "review" | "flag_high", "advisory": str # short human-readable reason } """ from __future__ import annotations import logging import os import time from typing import List, Optional import torch from fastapi import FastAPI, HTTPException, UploadFile, File, Form from pydantic import BaseModel, Field from transformers import AutoImageProcessor, AutoModelForImageClassification from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError logger = logging.getLogger("maturity") logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s", ) # --------------------------------------------------------------------------- # Configuration — all tunable via environment variables # --------------------------------------------------------------------------- MATURITY_MODEL = os.getenv("MATURITY_MODEL", "Falconsai/nsfw_image_detection") DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Main threshold: score >= this → "mature" + "flag_high" THRESHOLD_MATURE = float(os.getenv("MATURITY_THRESHOLD_MATURE", "0.80")) # Review band: score >= this (but below MATURE) → "mature" + "review" THRESHOLD_REVIEW = float(os.getenv("MATURITY_THRESHOLD_REVIEW", "0.60")) # Max image bytes — same default as the rest of the stack (50 MB) MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", str(50 * 1024 * 1024))) # --------------------------------------------------------------------------- # Model loading — done once at import time so Docker start captures it # --------------------------------------------------------------------------- logger.info("maturity service: loading model %s on %s", MATURITY_MODEL, DEVICE) _t_load = time.perf_counter() _processor = AutoImageProcessor.from_pretrained(MATURITY_MODEL) _model = AutoModelForImageClassification.from_pretrained(MATURITY_MODEL).to(DEVICE).eval() # Build a label→index map from the model config so we are not fragile to label # ordering changes. _ID2LABEL: dict[int, str] = _model.config.id2label # e.g. {0: "normal", 1: "nsfw"} _NSFW_IDX: int = next( (i for i, lbl in _ID2LABEL.items() if "nsfw" in lbl.lower() or "explicit" in lbl.lower()), 1, # fallback: assume index 1 is the NSFW class ) logger.info( "maturity service: model loaded elapsed_ms=%.1f device=%s id2label=%s nsfw_idx=%s", (time.perf_counter() - _t_load) * 1000, DEVICE, _ID2LABEL, _NSFW_IDX, ) app = FastAPI(title="Skinbase Maturity Service", version="1.0.0") # --------------------------------------------------------------------------- # Schemas # --------------------------------------------------------------------------- class MaturityRequest(BaseModel): url: Optional[str] = Field(default=None, description="Public image URL to analyse") class MaturityResponse(BaseModel): maturity_label: str = Field(description='"safe" or "mature"') confidence: float = Field(description="Confidence in the maturity_label decision (0–1)") score: float = Field(description="Raw NSFW probability from the model (0–1)") labels: List[str] = Field(description="Sublabels when mature content is detected") model: str = Field(description="Model identifier / version") threshold_used: float = Field(description="Threshold applied to produce the label") analysis_time_ms: float source: str = "maturity-service" action_hint: str = Field(description='"safe", "review", or "flag_high"') advisory: str = Field(description="Short human-readable reason for the decision") # --------------------------------------------------------------------------- # Inference helper # --------------------------------------------------------------------------- def _run_inference(data: bytes) -> MaturityResponse: """Run maturity inference on raw image bytes and return a structured response.""" t0 = time.perf_counter() try: img = bytes_to_pil(data) except Exception as exc: raise HTTPException(status_code=400, detail=f"Cannot decode image: {exc}") from exc inputs = _processor(images=img, return_tensors="pt").to(DEVICE) with torch.no_grad(): logits = _model(**inputs).logits probs = torch.softmax(logits, dim=-1)[0] nsfw_score = float(probs[_NSFW_IDX]) elapsed_ms = (time.perf_counter() - t0) * 1000 # Derive label, action_hint, advisory, sublabels if nsfw_score >= THRESHOLD_MATURE: maturity_label = "mature" action_hint = "flag_high" advisory = "High-confidence mature content detected" labels = ["nsfw"] threshold_used = THRESHOLD_MATURE confidence = nsfw_score elif nsfw_score >= THRESHOLD_REVIEW: maturity_label = "mature" action_hint = "review" advisory = "Possible mature content — review recommended" labels = ["nsfw"] threshold_used = THRESHOLD_REVIEW confidence = nsfw_score else: maturity_label = "safe" action_hint = "safe" advisory = "Content appears safe" labels = [] threshold_used = THRESHOLD_REVIEW confidence = 1.0 - nsfw_score # confidence in the "safe" verdict logger.info( "maturity inference: maturity_label=%s action_hint=%s score=%.4f " "confidence=%.4f threshold_mature=%.2f threshold_review=%.2f elapsed_ms=%.1f", maturity_label, action_hint, nsfw_score, confidence, THRESHOLD_MATURE, THRESHOLD_REVIEW, elapsed_ms, ) return MaturityResponse( maturity_label=maturity_label, confidence=round(confidence, 4), score=round(nsfw_score, 4), labels=labels, model=MATURITY_MODEL, threshold_used=threshold_used, analysis_time_ms=round(elapsed_ms, 1), source="maturity-service", action_hint=action_hint, advisory=advisory, ) # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @app.get("/health") def health(): return { "status": "ok", "device": DEVICE, "model": MATURITY_MODEL, "threshold_mature": THRESHOLD_MATURE, "threshold_review": THRESHOLD_REVIEW, } @app.post("/analyze", response_model=MaturityResponse) def analyze(req: MaturityRequest): """URL-based maturity analysis.""" if not req.url: raise HTTPException(status_code=400, detail="url is required") try: data = fetch_url_bytes(req.url, max_bytes=MAX_IMAGE_BYTES) except ImageLoadError as exc: logger.warning("maturity analyze: image load failed url=%s error=%s", req.url, exc) raise HTTPException(status_code=400, detail=str(exc)) from exc return _run_inference(data) @app.post("/analyze/file", response_model=MaturityResponse) async def analyze_file(file: UploadFile = File(...)): """Multipart file-upload maturity analysis.""" data = await file.read() if len(data) > MAX_IMAGE_BYTES: raise HTTPException( status_code=413, detail=f"File exceeds maximum allowed size of {MAX_IMAGE_BYTES} bytes", ) if not data: raise HTTPException(status_code=400, detail="Empty file upload") return _run_inference(data)