from __future__ import annotations import os from typing import List, Optional import torch import open_clip from fastapi import FastAPI, HTTPException, UploadFile, File, Form from pydantic import BaseModel, Field import numpy as np from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError MODEL_NAME = os.getenv("MODEL_NAME", "ViT-B-32") MODEL_PRETRAINED = os.getenv("MODEL_PRETRAINED", "openai") DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Starter vocab (replace with DB-driven vocab later) TAGS: List[str] = [ "wallpaper", "4k wallpaper", "8k wallpaper", "cyberpunk", "neon", "city", "night", "sci-fi", "space", "fantasy", "anime", "digital art", "abstract", "minimal", "landscape", "nature", "mountains", "forest", "ocean", "sunset", "photography", "portrait", "architecture", "cars", "gaming", ] app = FastAPI(title="Skinbase CLIP Service", version="1.0.0") model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=MODEL_PRETRAINED) tokenizer = open_clip.get_tokenizer(MODEL_NAME) model = model.to(DEVICE).eval() class AnalyzeRequest(BaseModel): url: Optional[str] = None limit: int = Field(default=5, ge=1, le=50) threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0) class EmbedRequest(BaseModel): url: Optional[str] = None backend: Optional[str] = Field(default="openclip", pattern="^(openclip|hf)$") model: Optional[str] = None pretrained: Optional[str] = None @app.get("/health") def health(): return {"status": "ok", "device": DEVICE, "model": MODEL_NAME, "pretrained": MODEL_PRETRAINED} def _analyze_image_bytes(data: bytes, limit: int, threshold: Optional[float]): img = bytes_to_pil(data) image_input = preprocess(img).unsqueeze(0).to(DEVICE) text = tokenizer(TAGS).to(DEVICE) with torch.no_grad(): image_features = model.encode_image(image_input) text_features = model.encode_text(text) # Normalize so dot product approximates cosine similarity image_features = image_features / image_features.norm(dim=-1, keepdim=True) text_features = text_features / text_features.norm(dim=-1, keepdim=True) logits = (image_features @ text_features.T) probs = logits.softmax(dim=-1) topk = probs[0].topk(min(limit, len(TAGS))) results = [] for score, idx in zip(topk.values, topk.indices): conf = float(score) if threshold is not None and conf < float(threshold): continue results.append({"tag": TAGS[int(idx)], "confidence": conf}) return {"tags": results, "model": MODEL_NAME, "dim": int(text_features.shape[-1])} def _embed_image_bytes(data: bytes, backend: str = "openclip", model_name: Optional[str] = None, pretrained: Optional[str] = None): img = bytes_to_pil(data) if backend == "openclip": # prefer already-loaded model when matching global config use_model_name = model_name or MODEL_NAME use_pretrained = pretrained or MODEL_PRETRAINED if use_model_name == MODEL_NAME and use_pretrained == MODEL_PRETRAINED: _model = model _preprocess = preprocess device = DEVICE else: import open_clip as _oc _model, _, _preprocess = _oc.create_model_and_transforms(use_model_name, pretrained=use_pretrained) device = "cuda" if torch.cuda.is_available() else "cpu" _model = _model.to(device).eval() image_input = _preprocess(img).unsqueeze(0).to(device) with torch.no_grad(): image_features = _model.encode_image(image_input) image_features = image_features / image_features.norm(dim=-1, keepdim=True) vec = image_features.cpu().numpy()[0] else: # HuggingFace CLIP backend from transformers import CLIPProcessor, CLIPModel hf_model_name = model_name or "openai/clip-vit-base-patch32" device = "cuda" if torch.cuda.is_available() else "cpu" hf_model = CLIPModel.from_pretrained(hf_model_name).to(device).eval() processor = CLIPProcessor.from_pretrained(hf_model_name) inputs = processor(images=img, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): feats = hf_model.get_image_features(**inputs) feats = feats / feats.norm(dim=-1, keepdim=True) vec = feats.cpu().numpy()[0] return {"vector": vec.tolist(), "dim": int(np.asarray(vec).shape[-1]), "backend": backend, "model": model_name or (MODEL_NAME if backend == "openclip" else None)} @app.post("/analyze") def analyze(req: AnalyzeRequest): if not req.url: raise HTTPException(400, "url is required") try: data = fetch_url_bytes(req.url) return _analyze_image_bytes(data, req.limit, req.threshold) except ImageLoadError as e: raise HTTPException(400, str(e)) @app.post("/analyze/file") async def analyze_file( file: UploadFile = File(...), limit: int = Form(5), threshold: Optional[float] = Form(None), ): data = await file.read() return _analyze_image_bytes(data, int(limit), threshold) @app.post("/embed") def embed(req: EmbedRequest): if not req.url: raise HTTPException(400, "url is required") try: data = fetch_url_bytes(req.url) return _embed_image_bytes(data, backend=req.backend, model_name=req.model, pretrained=req.pretrained) except ImageLoadError as e: raise HTTPException(400, str(e)) @app.post("/embed/file") async def embed_file( file: UploadFile = File(...), backend: str = Form("openclip"), model: Optional[str] = Form(None), pretrained: Optional[str] = Form(None), ): data = await file.read() return _embed_image_bytes(data, backend=backend, model_name=model, pretrained=pretrained)