first commit
This commit is contained in:
70
yolo/main.py
Normal file
70
yolo/main.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel, Field
|
||||
from ultralytics import YOLO
|
||||
|
||||
from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError
|
||||
|
||||
YOLO_MODEL = os.getenv("YOLO_MODEL", "yolov8n.pt")
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
app = FastAPI(title="Skinbase YOLO Service", version="1.0.0")
|
||||
|
||||
model = YOLO(YOLO_MODEL)
|
||||
|
||||
|
||||
class DetectRequest(BaseModel):
|
||||
url: Optional[str] = None
|
||||
conf: float = Field(default=0.25, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"status": "ok", "device": DEVICE, "model": YOLO_MODEL}
|
||||
|
||||
|
||||
def _detect_bytes(data: bytes, conf: float):
|
||||
img = bytes_to_pil(data)
|
||||
|
||||
results = model(img)
|
||||
|
||||
best: Dict[str, float] = {}
|
||||
for r in results:
|
||||
for box in r.boxes:
|
||||
score = float(box.conf[0])
|
||||
if score < conf:
|
||||
continue
|
||||
cls_id = int(box.cls[0])
|
||||
label = model.names.get(cls_id, str(cls_id))
|
||||
if label not in best or best[label] < score:
|
||||
best[label] = score
|
||||
|
||||
detections = [{"label": k, "confidence": v} for k, v in best.items()]
|
||||
detections.sort(key=lambda x: x["confidence"], reverse=True)
|
||||
|
||||
return {"detections": detections, "model": YOLO_MODEL}
|
||||
|
||||
|
||||
@app.post("/detect")
|
||||
def detect(req: DetectRequest):
|
||||
if not req.url:
|
||||
raise HTTPException(400, "url is required")
|
||||
try:
|
||||
data = fetch_url_bytes(req.url)
|
||||
return _detect_bytes(data, float(req.conf))
|
||||
except ImageLoadError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
|
||||
|
||||
@app.post("/detect/file")
|
||||
async def detect_file(
|
||||
file: UploadFile = File(...),
|
||||
conf: float = Form(0.25),
|
||||
):
|
||||
data = await file.read()
|
||||
return _detect_bytes(data, float(conf))
|
||||
Reference in New Issue
Block a user