Files
vision/blip/main.py
2026-03-21 09:09:28 +01:00

79 lines
2.4 KiB
Python

from __future__ import annotations
import os
from typing import Optional
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
from transformers import BlipProcessor, BlipForConditionalGeneration
from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError
# Checkpoint id to load; can be overridden through the BLIP_MODEL env variable.
BLIP_MODEL = os.environ.get("BLIP_MODEL", "Salesforce/blip-image-captioning-base")

# Run on CUDA when torch reports it available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

app = FastAPI(title="Skinbase BLIP Service", version="1.0.0")

# Processor and model are loaded once, at import time, and reused by every
# request handled by this module.
processor = BlipProcessor.from_pretrained(BLIP_MODEL)
model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL)
model = model.to(DEVICE)
model.eval()  # inference only: disables dropout etc.
class CaptionRequest(BaseModel):
    # JSON body of POST /caption.

    # Image URL to fetch; optional at the schema level, but the /caption
    # handler rejects a missing/empty value with HTTP 400.
    url: Optional[str] = Field(default=None)
    # Requested caption count: the handler draws up to (variants - 1) extra
    # sampled captions in addition to the beam-search caption.
    variants: int = Field(3, ge=0, le=10)
    # Forwarded as `max_length` to model.generate.
    max_length: int = Field(60, ge=10, le=200)
@app.get("/health")
def health():
    # Status endpoint: echoes the device and checkpoint chosen at import time.
    payload = dict(status="ok", device=DEVICE, model=BLIP_MODEL)
    return payload
def _caption_bytes(data: bytes, variants: int, max_length: int):
    """Caption encoded image bytes with the loaded BLIP model.

    Produces one deterministic beam-search caption plus up to ``variants - 1``
    sampled alternatives.

    Args:
        data: Encoded image bytes, decoded with ``bytes_to_pil``.
        variants: Requested caption count *including* the base caption. Only
            ``variants - 1`` extra samples are drawn (none when ``variants``
            is 0 or 1). Samples identical to the base caption or to each other
            are dropped, so fewer variants than requested may be returned.
        max_length: Forwarded as ``max_length`` to ``model.generate``.

    Returns:
        dict with ``caption`` (str), ``variants`` (list[str], first-seen order)
        and ``model`` (the checkpoint id).

    Raises:
        ImageLoadError: presumably raised by ``bytes_to_pil`` for undecodable
            input (the /caption handler catches it) -- confirm in
            common.image_io.
    """
    img = bytes_to_pil(data)
    inputs = processor(img, return_tensors="pt").to(DEVICE)
    extra = max(0, variants - 1)
    sampled = []

    with torch.no_grad():
        out = model.generate(**inputs, max_length=max_length, num_beams=5)
        if extra:
            # One batched sampling call instead of `extra` sequential calls:
            # the image is encoded once, and each returned sequence is still an
            # independent sample with the same decoding settings.
            sampled_out = model.generate(
                **inputs,
                max_length=max_length,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.9,
                num_return_sequences=extra,
            )
            sampled = processor.batch_decode(sampled_out, skip_special_tokens=True)

    base_caption = processor.decode(out[0], skip_special_tokens=True)
    # Best-effort uniqueness: dict.fromkeys de-duplicates while keeping the
    # first-seen order; the base caption itself is never repeated as a variant.
    variant_list = [text for text in dict.fromkeys(sampled) if text != base_caption]
    return {"caption": base_caption, "variants": variant_list, "model": BLIP_MODEL}
@app.post("/caption")
def caption(req: CaptionRequest):
    """Caption the image found at ``req.url``.

    Raises:
        HTTPException: 400 when ``url`` is missing/empty, or when fetching or
            decoding the image raises ``ImageLoadError``.
    """
    if not req.url:
        raise HTTPException(400, "url is required")
    # NOTE(review): req.url is caller-controlled and fetched server-side;
    # confirm fetch_url_bytes restricts schemes/hosts (SSRF).
    try:
        data = fetch_url_bytes(req.url)
        result = _caption_bytes(data, req.variants, req.max_length)
    except ImageLoadError as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(400, str(e)) from e
    return result
@app.post("/caption/file")
async def caption_file(
    file: UploadFile = File(...),
    variants: int = Form(3, ge=0, le=10),
    max_length: int = Form(60, ge=10, le=200),
):
    """Caption an uploaded image file.

    The ``variants`` / ``max_length`` form fields carry the same bounds as the
    JSON ``CaptionRequest`` model, so neither endpoint accepts unbounded
    generation work.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded
            (``ImageLoadError``), matching the URL-based /caption endpoint.
    """
    from fastapi.concurrency import run_in_threadpool

    data = await file.read()
    if not data:
        raise HTTPException(400, "file is empty")
    try:
        # Decoding and model inference are blocking; run them in the worker
        # threadpool so this async handler does not stall the event loop.
        return await run_in_threadpool(_caption_bytes, data, variants, max_length)
    except ImageLoadError as e:
        raise HTTPException(400, str(e)) from e