Files
vision/clip/main.py

162 lines
5.8 KiB
Python

from __future__ import annotations
import os
from typing import List, Optional
import torch
import open_clip
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
import numpy as np
from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError
# Model configuration comes from the environment so deployments can swap
# checkpoints without code changes.
MODEL_NAME = os.getenv("MODEL_NAME", "ViT-B-32")
MODEL_PRETRAINED = os.getenv("MODEL_PRETRAINED", "openai")
# Prefer GPU when available; every tensor below is moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Starter vocab (replace with DB-driven vocab later)
TAGS: List[str] = [
    "wallpaper", "4k wallpaper", "8k wallpaper",
    "cyberpunk", "neon", "city", "night", "sci-fi", "space",
    "fantasy", "anime", "digital art", "abstract", "minimal",
    "landscape", "nature", "mountains", "forest", "ocean", "sunset",
    "photography", "portrait", "architecture", "cars", "gaming",
]

app = FastAPI(title="Skinbase CLIP Service", version="1.0.0")

# Load the default OpenCLIP model once at import time; request handlers reuse
# these module-level objects instead of reloading weights per call.
model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=MODEL_PRETRAINED)
tokenizer = open_clip.get_tokenizer(MODEL_NAME)
model = model.to(DEVICE).eval()
class AnalyzeRequest(BaseModel):
    """Request body for POST /analyze."""

    # Image URL to fetch. Declared Optional, but the endpoint rejects None
    # with a 400 — effectively required.
    url: Optional[str] = None
    # Maximum number of tags to return (bounded 1-50 by validation).
    limit: int = Field(default=5, ge=1, le=50)
    # Optional minimum confidence in [0, 1]; lower-scoring tags are dropped.
    threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)
class EmbedRequest(BaseModel):
    """Request body for POST /embed."""

    # Image URL to fetch. Declared Optional, but the endpoint rejects None
    # with a 400 — effectively required.
    url: Optional[str] = None
    # Embedding backend; pattern restricts non-None values to these two.
    # NOTE(review): an explicit null bypasses the pattern check — pydantic
    # skips `pattern` for None.
    backend: Optional[str] = Field(default="openclip", pattern="^(openclip|hf)$")
    # Optional model override; default depends on the chosen backend.
    model: Optional[str] = None
    # OpenCLIP pretrained tag override; unused by the hf backend.
    pretrained: Optional[str] = None
@app.get("/health")
def health():
    """Liveness probe reporting the active device and model configuration."""
    payload = {
        "status": "ok",
        "device": DEVICE,
        "model": MODEL_NAME,
        "pretrained": MODEL_PRETRAINED,
    }
    return payload
def _analyze_image_bytes(data: bytes, limit: int, threshold: Optional[float]):
    """Zero-shot tag an image against the module-level TAGS vocabulary.

    Args:
        data: Raw image bytes (anything PIL can decode via bytes_to_pil).
        limit: Maximum number of tags to return.
        threshold: Optional minimum confidence in [0, 1]; lower-scoring tags
            are dropped from the result.

    Returns:
        Dict with ``tags`` (list of {tag, confidence}), ``model`` and ``dim``.
    """
    img = bytes_to_pil(data)
    image_input = preprocess(img).unsqueeze(0).to(DEVICE)
    text = tokenizer(TAGS).to(DEVICE)
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text)
        # Normalize so dot product equals cosine similarity
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # BUG FIX: cosine similarities live in [-1, 1]; softmax over such a
        # narrow range yields a near-uniform distribution (~1/len(TAGS) per
        # tag), making confidences and `threshold` meaningless. CLIP applies
        # its learned temperature (logit_scale) before the softmax.
        logit_scale = model.logit_scale.exp()
        logits = logit_scale * (image_features @ text_features.T)
        probs = logits.softmax(dim=-1)
    topk = probs[0].topk(min(limit, len(TAGS)))
    results = []
    for score, idx in zip(topk.values, topk.indices):
        conf = float(score)
        if threshold is not None and conf < float(threshold):
            continue
        results.append({"tag": TAGS[int(idx)], "confidence": conf})
    return {"tags": results, "model": MODEL_NAME, "dim": int(text_features.shape[-1])}
def _embed_image_bytes(data: bytes, backend: str = "openclip", model_name: Optional[str] = None, pretrained: Optional[str] = None):
    """Compute a unit-length CLIP image embedding.

    Args:
        data: Raw image bytes.
        backend: "openclip" (default) or "hf". None is treated as the default.
        model_name: Optional model override; per-backend default otherwise.
        pretrained: OpenCLIP pretrained tag override (ignored by the hf backend).

    Returns:
        Dict with ``vector`` (list of floats), ``dim``, ``backend`` and the
        name of the model actually used.

    Raises:
        ValueError: If backend is not "openclip" or "hf".
    """
    # FIX: an explicit backend=None used to fall silently into the HF branch
    # (pydantic skips `pattern` validation for None). Normalize to the
    # documented default and reject anything else outright.
    backend = backend or "openclip"
    if backend not in ("openclip", "hf"):
        raise ValueError(f"unsupported backend: {backend!r}")

    img = bytes_to_pil(data)
    # FIX: non-default models were re-instantiated (weights reloaded) on
    # every request. Keep them in a process-level cache keyed by config.
    cache = _embed_image_bytes.__dict__.setdefault("_model_cache", {})

    if backend == "openclip":
        use_model_name = model_name or MODEL_NAME
        use_pretrained = pretrained or MODEL_PRETRAINED
        if use_model_name == MODEL_NAME and use_pretrained == MODEL_PRETRAINED:
            # Fast path: reuse the module-level model loaded at startup.
            _model, _preprocess, device = model, preprocess, DEVICE
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            key = ("openclip", use_model_name, use_pretrained, device)
            if key not in cache:
                _m, _, _p = open_clip.create_model_and_transforms(use_model_name, pretrained=use_pretrained)
                cache[key] = (_m.to(device).eval(), _p)
            _model, _preprocess = cache[key]
        image_input = _preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            feats = _model.encode_image(image_input)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        vec = feats.cpu().numpy()[0]
        used_model = use_model_name
    else:
        # HuggingFace CLIP backend (lazy import: transformers is optional).
        from transformers import CLIPProcessor, CLIPModel
        hf_model_name = model_name or "openai/clip-vit-base-patch32"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        key = ("hf", hf_model_name, device)
        if key not in cache:
            hf_model = CLIPModel.from_pretrained(hf_model_name).to(device).eval()
            processor = CLIPProcessor.from_pretrained(hf_model_name)
            cache[key] = (hf_model, processor)
        hf_model, processor = cache[key]
        inputs = processor(images=img, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            feats = hf_model.get_image_features(**inputs)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        vec = feats.cpu().numpy()[0]
        # FIX: previously reported None for the default HF model; report the
        # checkpoint actually used instead.
        used_model = hf_model_name

    return {"vector": vec.tolist(), "dim": int(vec.shape[-1]), "backend": backend, "model": used_model}
@app.post("/analyze")
def analyze(req: AnalyzeRequest):
    """Fetch an image from a URL and tag it against the built-in vocabulary."""
    if req.url is None or req.url == "":
        raise HTTPException(400, "url is required")
    try:
        # Decoding happens inside the helper, so it stays under this handler
        # to turn ImageLoadError into a client error as well.
        payload = fetch_url_bytes(req.url)
        return _analyze_image_bytes(payload, req.limit, req.threshold)
    except ImageLoadError as exc:
        raise HTTPException(400, str(exc))
@app.post("/analyze/file")
async def analyze_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    threshold: Optional[float] = Form(None),
):
    """Tag an uploaded image file.

    FIX: mirrors the JSON endpoint's behavior — `limit` is clamped to the
    same 1-50 range AnalyzeRequest enforces, and an undecodable upload
    returns 400 (previously ImageLoadError escaped as a 500).
    """
    data = await file.read()
    limit = max(1, min(int(limit), 50))
    try:
        return _analyze_image_bytes(data, limit, threshold)
    except ImageLoadError as e:
        raise HTTPException(400, str(e))
@app.post("/embed")
def embed(req: EmbedRequest):
    """Fetch an image from a URL and return its CLIP embedding."""
    if req.url is None or req.url == "":
        raise HTTPException(400, "url is required")
    try:
        payload = fetch_url_bytes(req.url)
        return _embed_image_bytes(
            payload,
            backend=req.backend,
            model_name=req.model,
            pretrained=req.pretrained,
        )
    except ImageLoadError as exc:
        raise HTTPException(400, str(exc))
@app.post("/embed/file")
async def embed_file(
    file: UploadFile = File(...),
    backend: str = Form("openclip"),
    model: Optional[str] = Form(None),
    pretrained: Optional[str] = Form(None),
):
    """Embed an uploaded image file.

    FIX: consistent with /embed — an undecodable upload now returns 400
    (previously ImageLoadError from decoding escaped as a 500).
    """
    data = await file.read()
    try:
        return _embed_image_bytes(data, backend=backend, model_name=model, pretrained=pretrained)
    except ImageLoadError as e:
        raise HTTPException(400, str(e))