first commit

This commit is contained in:
2026-03-21 09:09:28 +01:00
commit 8da669c0e1
23 changed files with 1812 additions and 0 deletions

78
blip/main.py Normal file
View File

@@ -0,0 +1,78 @@
from __future__ import annotations
import os
from typing import Optional
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
from transformers import BlipProcessor, BlipForConditionalGeneration
from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError
BLIP_MODEL = os.getenv("BLIP_MODEL", "Salesforce/blip-image-captioning-base")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
app = FastAPI(title="Skinbase BLIP Service", version="1.0.0")
processor = BlipProcessor.from_pretrained(BLIP_MODEL)
model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL).to(DEVICE).eval()
class CaptionRequest(BaseModel):
url: Optional[str] = None
variants: int = Field(default=3, ge=0, le=10)
max_length: int = Field(default=60, ge=10, le=200)
@app.get("/health")
def health():
return {"status": "ok", "device": DEVICE, "model": BLIP_MODEL}
def _caption_bytes(data: bytes, variants: int, max_length: int):
img = bytes_to_pil(data)
inputs = processor(img, return_tensors="pt").to(DEVICE)
with torch.no_grad():
out = model.generate(**inputs, max_length=max_length, num_beams=5)
base_caption = processor.decode(out[0], skip_special_tokens=True)
variant_list = []
# generate additional variants using sampling (best-effort uniqueness)
for _ in range(max(0, variants - 1)):
with torch.no_grad():
out2 = model.generate(
**inputs,
max_length=max_length,
do_sample=True,
top_k=50,
top_p=0.95,
temperature=0.9,
)
text = processor.decode(out2[0], skip_special_tokens=True)
if text != base_caption and text not in variant_list:
variant_list.append(text)
return {"caption": base_caption, "variants": variant_list, "model": BLIP_MODEL}
@app.post("/caption")
def caption(req: CaptionRequest):
if not req.url:
raise HTTPException(400, "url is required")
try:
data = fetch_url_bytes(req.url)
return _caption_bytes(data, req.variants, req.max_length)
except ImageLoadError as e:
raise HTTPException(400, str(e))
@app.post("/caption/file")
async def caption_file(
file: UploadFile = File(...),
variants: int = Form(3),
max_length: int = Form(60),
):
data = await file.read()
return _caption_bytes(data, int(variants), int(max_length))