Files
vision/yolo/main.py
2026-03-21 09:09:28 +01:00

71 lines
1.8 KiB
Python

from __future__ import annotations

import os
import threading
from typing import Optional, Dict

import torch
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from ultralytics import YOLO

from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError
# Weights file / model name handed to ultralytics; override with the
# YOLO_MODEL environment variable.
YOLO_MODEL = os.environ.get("YOLO_MODEL", "yolov8n.pt")

# Device string picked once at import time: CUDA when torch reports it
# available, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

app = FastAPI(title="Skinbase YOLO Service", version="1.0.0")

# Loaded once at import time; the module-level handlers all use this instance.
model = YOLO(YOLO_MODEL)
# JSON request body for POST /detect.
class DetectRequest(BaseModel):
    # Image URL to fetch. Optional at the schema level; the /detect handler
    # itself rejects a missing or empty value with HTTP 400.
    url: Optional[str] = Field(default=None)
    # Minimum box confidence to report, constrained to the closed range [0, 1].
    conf: float = Field(0.25, ge=0.0, le=1.0)
@app.get("/health")
def health():
    """Report service status together with the device string and model name."""
    payload = dict(status="ok", device=DEVICE, model=YOLO_MODEL)
    return payload
# Serializes calls into the shared YOLO instance. Ultralytics documents that a
# single model object is not safe to call from several threads at once, and
# FastAPI dispatches sync handlers (and threadpool work) on worker threads.
_MODEL_LOCK = threading.Lock()


def _detect_bytes(data: bytes, conf: float) -> Dict[str, object]:
    """Run YOLO on encoded image bytes and keep the best score per label.

    Args:
        data: Raw image bytes, decoded with ``bytes_to_pil``.
        conf: Minimum confidence (0..1) a box needs in order to be reported.

    Returns:
        ``{"detections": [{"label": str, "confidence": float}, ...],
        "model": YOLO_MODEL}`` -- one entry per label (its highest-scoring
        box), sorted by confidence, highest first.

    Raises:
        Propagates whatever ``bytes_to_pil`` raises for undecodable input
        (the HTTP handlers translate ``ImageLoadError`` into a 400).
    """
    img = bytes_to_pil(data)
    # Hand ``conf`` to the predictor: otherwise ultralytics applies its own
    # default threshold (0.25) first, and any lower ``conf`` has no effect.
    with _MODEL_LOCK:
        results = model(img, conf=conf, device=DEVICE, verbose=False)

    names = model.names
    best: Dict[str, float] = {}
    for r in results:
        boxes = r.boxes
        if boxes is None:
            # Non-detection weights (e.g. classification) produce no boxes.
            continue
        # One host transfer per result instead of indexing tensors per box.
        for score, cls_val in zip(boxes.conf.tolist(), boxes.cls.tolist()):
            # Kept as a safeguard; the predictor already filtered with ``conf``.
            if score < conf:
                continue
            cls_id = int(cls_val)
            label = names.get(cls_id, str(cls_id))
            if label not in best or best[label] < score:
                best[label] = score

    detections = [{"label": k, "confidence": v} for k, v in best.items()]
    detections.sort(key=lambda x: x["confidence"], reverse=True)
    return {"detections": detections, "model": YOLO_MODEL}
@app.post("/detect")
def detect(req: DetectRequest):
    """Detect objects in the image located at ``req.url``.

    Returns the ``_detect_bytes`` payload. Responds with HTTP 400 when ``url``
    is missing or empty, or when fetching/decoding raises ``ImageLoadError``.
    """
    if not req.url:
        raise HTTPException(status_code=400, detail="url is required")
    try:
        data = fetch_url_bytes(req.url)
        # Inference stays inside the try: decoding happens in _detect_bytes
        # and may also surface ImageLoadError.
        return _detect_bytes(data, float(req.conf))
    except ImageLoadError as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.post("/detect/file")
async def detect_file(
    file: UploadFile = File(...),
    conf: float = Form(0.25, ge=0.0, le=1.0),
):
    """Detect objects in an uploaded image file.

    ``conf`` is validated to [0, 1], matching ``DetectRequest.conf``.
    An ``ImageLoadError`` while decoding maps to HTTP 400, as in ``/detect``,
    instead of escaping as an unhandled 500.
    """
    data = await file.read()
    try:
        # Model inference is blocking work; running it directly inside this
        # ``async`` handler would stall the event loop for all other requests.
        return await run_in_threadpool(_detect_bytes, data, float(conf))
    except ImageLoadError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e