"""HTTP service that captions images with a BLIP model.

The processor and model are loaded once at import time and shared by every
request handler in this module.
"""
from __future__ import annotations

import os
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
from transformers import BlipProcessor, BlipForConditionalGeneration

from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError

# Checkpoint identifier; can be overridden with the BLIP_MODEL env variable.
BLIP_MODEL = os.getenv("BLIP_MODEL", "Salesforce/blip-image-captioning-base")
# Run on the GPU when torch reports one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

app = FastAPI(title="Skinbase BLIP Service", version="1.0.0")

# Module-level singletons: loaded at import, moved to DEVICE, model put in
# eval mode.
processor = BlipProcessor.from_pretrained(BLIP_MODEL)
model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL).to(DEVICE).eval()


class CaptionRequest(BaseModel):
    """JSON body accepted by ``POST /caption``."""

    # Image URL. Optional in the schema, but the handler answers 400 when it
    # is missing or empty.
    url: Optional[str] = None
    # Total number of captions asked for, counting the beam-search caption;
    # at most ``variants - 1`` sampled alternatives are generated.
    variants: int = Field(default=3, ge=0, le=10)
    # Forwarded to ``model.generate`` as ``max_length``.
    max_length: int = Field(default=60, ge=10, le=200)


@app.get("/health")
def health():
    """Report a static OK status plus the configured device and model name."""
    return {"status": "ok", "device": DEVICE, "model": BLIP_MODEL}


def _caption_bytes(data: bytes, variants: int, max_length: int) -> dict:
    """Caption an encoded image and optionally sample alternative captions.

    Args:
        data: Raw image bytes, decoded with ``bytes_to_pil``.
        variants: Total captions requested including the base caption; the
            sampling loop runs ``max(0, variants - 1)`` times.
        max_length: ``max_length`` passed to every ``model.generate`` call.

    Returns:
        A dict with keys ``"caption"`` (beam-search result), ``"variants"``
        (distinct sampled captions that differ from the base caption, in
        generation order) and ``"model"`` (the checkpoint name).

    Any exception raised by ``bytes_to_pil`` propagates to the caller.
    """
    img = bytes_to_pil(data)
    inputs = processor(img, return_tensors="pt").to(DEVICE)

    variant_list: list[str] = []
    # One no_grad scope for all generation instead of re-entering it on every
    # sampling iteration.
    with torch.no_grad():
        out = model.generate(**inputs, max_length=max_length, num_beams=5)
        base_caption = processor.decode(out[0], skip_special_tokens=True)

        # Sampling is best-effort: duplicates of the base caption or of an
        # earlier sample are dropped, so fewer than ``variants - 1`` entries
        # may come back.
        for _ in range(max(0, variants - 1)):
            out2 = model.generate(
                **inputs,
                max_length=max_length,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.9,
            )
            text = processor.decode(out2[0], skip_special_tokens=True)
            if text != base_caption and text not in variant_list:
                variant_list.append(text)

    return {"caption": base_caption, "variants": variant_list, "model": BLIP_MODEL}


@app.post("/caption")
def caption(req: CaptionRequest):
    """Fetch the image at ``req.url`` and return its captions.

    Raises:
        HTTPException: 400 when ``url`` is missing/empty, or when an
            ``ImageLoadError`` is raised while fetching or captioning.
    """
    if not req.url:
        raise HTTPException(400, "url is required")
    # NOTE(review): req.url is caller-controlled and fetched server-side;
    # confirm fetch_url_bytes restricts schemes/hosts (SSRF) and response size.
    try:
        data = fetch_url_bytes(req.url)
        return _caption_bytes(data, req.variants, req.max_length)
    except ImageLoadError as e:
        # Chain the cause so the original failure stays visible in tracebacks.
        raise HTTPException(400, str(e)) from e
@app.post("/caption/file")
async def caption_file(
    file: UploadFile = File(...),
    # Same bounds as the JSON endpoint's CaptionRequest so a form client
    # cannot request unbounded generation work.
    variants: int = Form(3, ge=0, le=10),
    max_length: int = Form(60, ge=10, le=200),
):
    """Caption an image uploaded as multipart form data.

    Args:
        file: The uploaded image.
        variants: Total captions requested, including the base caption.
        max_length: Generation length limit forwarded to ``_caption_bytes``.

    Returns:
        The dict produced by ``_caption_bytes``.

    Raises:
        HTTPException: 400 when the upload is empty or when captioning raises
            ``ImageLoadError`` (mirrors the ``/caption`` endpoint).
    """
    # Imported locally so this handler does not depend on edits to the
    # file-level import block.
    from fastapi.concurrency import run_in_threadpool

    data = await file.read()
    if not data:
        raise HTTPException(400, "uploaded file is empty")

    try:
        # Model inference is blocking; run it in the threadpool so the event
        # loop stays free to serve other requests.
        return await run_in_threadpool(_caption_bytes, data, variants, max_length)
    except ImageLoadError as e:
        raise HTTPException(400, str(e)) from e