Files
vision/gateway/main.py

306 lines
11 KiB
Python

from __future__ import annotations

import asyncio
import hmac
import os
from typing import Any, Dict, Optional
from urllib.parse import quote

import httpx
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from starlette.middleware.base import BaseHTTPMiddleware
# Base URLs of the upstream services this gateway forwards to. Each one can be
# overridden through the environment variable of the same name.
CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000")
BLIP_URL = os.getenv("BLIP_URL", "http://blip:8000")
YOLO_URL = os.getenv("YOLO_URL", "http://yolo:8000")
QDRANT_SVC_URL = os.getenv("QDRANT_SVC_URL", "http://qdrant-svc:8000")
# Timeout handed to the httpx clients used for forwarded requests (httpx
# interprets it in seconds). The /health probe uses its own fixed timeout.
VISION_TIMEOUT = float(os.getenv("VISION_TIMEOUT", "20"))
# API key (set via env var `API_KEY`). If it is not set, the gateway rejects
# every request except the paths the middleware explicitly exempts.
API_KEY = os.getenv("API_KEY")
class APIKeyMiddleware(BaseHTTPMiddleware):
    """Reject requests that do not carry the configured API key.

    The key is read from the ``X-API-Key`` request header. The health and
    API-documentation paths are served without a key. When the module-level
    ``API_KEY`` is unset or empty, every other request is answered with 401.
    """

    # Paths that are reachable without an API key.
    _PUBLIC_PATHS = frozenset(("/health", "/openapi.json", "/docs", "/redoc"))

    async def dispatch(self, request: Request, call_next):
        """Pass the request on when it is public or correctly authenticated.

        Returns the downstream response, or a 401 JSON response with
        ``{"detail": "Unauthorized"}`` when the key is missing or wrong.
        """
        if request.url.path in self._PUBLIC_PATHS:
            return await call_next(request)

        # Header lookup is case-insensitive, so one lookup covers every
        # spelling of the header name.
        supplied = request.headers.get("x-api-key")
        if not API_KEY or not supplied:
            return JSONResponse(status_code=401, content={"detail": "Unauthorized"})

        # Constant-time comparison so response timing does not leak how much
        # of the key matched. Compare bytes: compare_digest rejects str
        # arguments containing non-ASCII characters.
        if not hmac.compare_digest(supplied.encode("utf-8"), API_KEY.encode("utf-8")):
            return JSONResponse(status_code=401, content={"detail": "Unauthorized"})

        return await call_next(request)
# The ASGI application. Every request passes through APIKeyMiddleware, which
# enforces the X-API-Key header on all but a few public paths.
app = FastAPI(title="Skinbase Vision Gateway", version="1.0.0")
app.add_middleware(APIKeyMiddleware)
class ClipRequest(BaseModel):
    # Request body for POST /analyze/clip.
    # `url` is declared optional here; the endpoint itself answers 400 when
    # it is missing or empty.
    url: Optional[str] = None
    # Must lie in 1..50; defaults to 5.
    limit: int = Field(5, ge=1, le=50)
    # Optional value in 0.0..1.0; forwarded as null when omitted.
    threshold: Optional[float] = Field(None, ge=0.0, le=1.0)
class BlipRequest(BaseModel):
    # Request body for POST /analyze/blip.
    # `url` is declared optional here; the endpoint itself answers 400 when
    # it is missing or empty.
    url: Optional[str] = None
    # Must lie in 0..10; defaults to 3.
    variants: int = Field(3, ge=0, le=10)
    # Must lie in 10..200; defaults to 60.
    max_length: int = Field(60, ge=10, le=200)
class YoloRequest(BaseModel):
    # Request body for POST /analyze/yolo.
    # `url` is declared optional here; the endpoint itself answers 400 when
    # it is missing or empty.
    url: Optional[str] = None
    # Must lie in 0.0..1.0; defaults to 0.25.
    conf: float = Field(0.25, ge=0.0, le=1.0)
async def _get_health(client: httpx.AsyncClient, base: str) -> Dict[str, Any]:
    """Probe ``{base}/health`` and summarise the outcome as a dict.

    Returns the decoded JSON body on HTTP 200, ``{"status": "bad", "code":
    <status>}`` on any other status, and ``{"status": "unreachable"}`` when
    the request or the JSON decoding raises. Never raises itself, so one
    failing service cannot break the aggregated health report.
    """
    try:
        reply = await client.get(f"{base}/health")
        if reply.status_code != 200:
            return {"status": "bad", "code": reply.status_code}
        return reply.json()
    except Exception:
        # Deliberately broad: any failure is reported as "unreachable".
        return {"status": "unreachable"}
async def _post_json(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON reply.

    Args:
        client: The httpx client used to send the request.
        url: Full upstream URL.
        payload: JSON-serialisable request body.

    Raises:
        HTTPException: 504 when the upstream request times out; 502 when the
            upstream cannot be reached, answers with a status >= 400, or
            returns a body that is not valid JSON.
    """
    try:
        r = await client.post(url, json=payload)
    except httpx.TimeoutException as exc:
        # Must precede RequestError: TimeoutException is one of its subclasses.
        raise HTTPException(status_code=504, detail=f"Upstream timeout {url}") from exc
    except httpx.RequestError as exc:
        # Connection/transport failures would otherwise surface as a bare 500.
        raise HTTPException(
            status_code=502,
            detail=f"Upstream unreachable {url}: {type(exc).__name__}",
        ) from exc
    if r.status_code >= 400:
        # Only the first 200 characters of the upstream body are echoed back.
        raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:200]}")
    try:
        return r.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail=f"Upstream returned invalid JSON {url}") from exc
async def _post_file(client: httpx.AsyncClient, url: str, data: bytes, fields: Dict[str, Any]) -> Dict[str, Any]:
    """POST ``data`` as a multipart file upload and return the decoded JSON reply.

    The bytes are sent as the ``file`` part (filename ``image``, content type
    ``application/octet-stream``); every entry of ``fields`` is sent as a
    form field with its value converted to ``str``.

    Args:
        client: The httpx client used to send the request.
        url: Full upstream URL.
        data: Raw file content.
        fields: Extra form fields to send alongside the file.

    Raises:
        HTTPException: 504 when the upstream request times out; 502 when the
            upstream cannot be reached, answers with a status >= 400, or
            returns a body that is not valid JSON.
    """
    files = {"file": ("image", data, "application/octet-stream")}
    form = {k: str(v) for k, v in fields.items()}
    try:
        r = await client.post(url, data=form, files=files)
    except httpx.TimeoutException as exc:
        # Must precede RequestError: TimeoutException is one of its subclasses.
        raise HTTPException(status_code=504, detail=f"Upstream timeout {url}") from exc
    except httpx.RequestError as exc:
        # Connection/transport failures would otherwise surface as a bare 500.
        raise HTTPException(
            status_code=502,
            detail=f"Upstream unreachable {url}: {type(exc).__name__}",
        ) from exc
    if r.status_code >= 400:
        # Only the first 200 characters of the upstream body are echoed back.
        raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:200]}")
    try:
        return r.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail=f"Upstream returned invalid JSON {url}") from exc
@app.get("/health")
async def health():
    # Probe all four upstream services concurrently and report each result.
    # The top-level "status" is always "ok": it reflects the gateway itself,
    # while per-service state is found under "services".
    service_names = ("clip", "blip", "yolo", "qdrant")
    base_urls = (CLIP_URL, BLIP_URL, YOLO_URL, QDRANT_SVC_URL)
    async with httpx.AsyncClient(timeout=5) as client:
        reports = await asyncio.gather(
            *(_get_health(client, base) for base in base_urls)
        )
    return {"status": "ok", "services": dict(zip(service_names, reports))}
# ---- Individual analyze endpoints (URL) ----
@app.post("/analyze/clip")
async def analyze_clip(req: ClipRequest):
    # Forward a URL-based request to the CLIP service's /analyze endpoint.
    # Answers 400 when `url` is missing or empty; otherwise returns the
    # upstream JSON unchanged.
    if not req.url:
        raise HTTPException(400, "url is required")
    target = f"{CLIP_URL}/analyze"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, req.model_dump())
    return result
@app.post("/analyze/blip")
async def analyze_blip(req: BlipRequest):
    # Forward a URL-based request to the BLIP service's /caption endpoint.
    # Answers 400 when `url` is missing or empty; otherwise returns the
    # upstream JSON unchanged.
    if not req.url:
        raise HTTPException(400, "url is required")
    target = f"{BLIP_URL}/caption"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, req.model_dump())
    return result
@app.post("/analyze/yolo")
async def analyze_yolo(req: YoloRequest):
    # Forward a URL-based request to the YOLO service's /detect endpoint.
    # Answers 400 when `url` is missing or empty; otherwise returns the
    # upstream JSON unchanged.
    if not req.url:
        raise HTTPException(400, "url is required")
    target = f"{YOLO_URL}/detect"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, req.model_dump())
    return result
# ---- Individual analyze endpoints (file upload) ----
@app.post("/analyze/clip/file")
async def analyze_clip_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    threshold: Optional[float] = Form(None),
):
    # Forward an uploaded file to the CLIP service's /analyze/file endpoint.
    # The whole upload is read into memory before it is forwarded.
    # `threshold` is only sent upstream when the caller supplied it.
    image_bytes = await file.read()
    form_fields: Dict[str, Any] = {"limit": int(limit)}
    if threshold is not None:
        form_fields["threshold"] = float(threshold)
    target = f"{CLIP_URL}/analyze/file"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_file(client, target, image_bytes, form_fields)
    return result
@app.post("/analyze/blip/file")
async def analyze_blip_file(
    file: UploadFile = File(...),
    variants: int = Form(3),
    max_length: int = Form(60),
):
    # Forward an uploaded file to the BLIP service's /caption/file endpoint.
    # The whole upload is read into memory before it is forwarded.
    image_bytes = await file.read()
    form_fields = {
        "variants": int(variants),
        "max_length": int(max_length),
    }
    target = f"{BLIP_URL}/caption/file"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_file(client, target, image_bytes, form_fields)
    return result
@app.post("/analyze/yolo/file")
async def analyze_yolo_file(
    file: UploadFile = File(...),
    conf: float = Form(0.25),
):
    # Forward an uploaded file to the YOLO service's /detect/file endpoint.
    # The whole upload is read into memory before it is forwarded.
    image_bytes = await file.read()
    form_fields = {"conf": float(conf)}
    target = f"{YOLO_URL}/detect/file"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_file(client, target, image_bytes, form_fields)
    return result
@app.post("/analyze/all")
async def analyze_all(payload: Dict[str, Any]):
    # Run CLIP, BLIP and YOLO on the same image URL concurrently and return
    # the three upstream replies under the keys "clip", "blip" and "yolo".
    #
    # Recognised payload keys: url (required), limit, threshold, variants,
    # max_length, conf. Answers 400 when `url` is missing or empty, or when a
    # numeric option cannot be converted to a number.
    url = payload.get("url")
    if not url:
        raise HTTPException(400, "url is required")

    # The payload is an untyped dict, so the conversions below can fail on
    # values such as "abc" or null. Report that as a client error rather
    # than letting it escape as an internal server error.
    try:
        clip_req = {"url": url, "limit": int(payload.get("limit", 5)), "threshold": payload.get("threshold")}
        blip_req = {
            "url": url,
            "variants": int(payload.get("variants", 3)),
            "max_length": int(payload.get("max_length", 60)),
        }
        yolo_req = {"url": url, "conf": float(payload.get("conf", 0.25))}
    except (TypeError, ValueError) as exc:
        raise HTTPException(400, f"invalid numeric parameter: {exc}") from exc

    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        clip_res, blip_res, yolo_res = await asyncio.gather(
            _post_json(client, f"{CLIP_URL}/analyze", clip_req),
            _post_json(client, f"{BLIP_URL}/caption", blip_req),
            _post_json(client, f"{YOLO_URL}/detect", yolo_req),
        )
    return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res}
# ---- Vector / Qdrant endpoints ----
@app.post("/vectors/upsert")
async def vectors_upsert(payload: Dict[str, Any]):
    # Pass the JSON body through unchanged to the Qdrant service's /upsert
    # endpoint and return its JSON reply.
    target = f"{QDRANT_SVC_URL}/upsert"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, payload)
    return result
@app.post("/vectors/upsert/file")
async def vectors_upsert_file(
    file: UploadFile = File(...),
    id: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
    metadata_json: Optional[str] = Form(None),
):
    # Forward an uploaded file to the Qdrant service's /upsert/file endpoint.
    # The whole upload is read into memory before it is forwarded. Only the
    # optional form fields the caller actually supplied are sent upstream.
    image_bytes = await file.read()
    supplied = {
        "id": id,
        "collection": collection,
        "metadata_json": metadata_json,
    }
    form_fields: Dict[str, Any] = {
        key: value for key, value in supplied.items() if value is not None
    }
    target = f"{QDRANT_SVC_URL}/upsert/file"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_file(client, target, image_bytes, form_fields)
    return result
@app.post("/vectors/upsert/vector")
async def vectors_upsert_vector(payload: Dict[str, Any]):
    # Pass the JSON body through unchanged to the Qdrant service's
    # /upsert/vector endpoint and return its JSON reply.
    target = f"{QDRANT_SVC_URL}/upsert/vector"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, payload)
    return result
@app.post("/vectors/search")
async def vectors_search(payload: Dict[str, Any]):
    # Pass the JSON body through unchanged to the Qdrant service's /search
    # endpoint and return its JSON reply.
    target = f"{QDRANT_SVC_URL}/search"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, payload)
    return result
@app.post("/vectors/search/file")
async def vectors_search_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    score_threshold: Optional[float] = Form(None),
    collection: Optional[str] = Form(None),
):
    # Forward an uploaded file to the Qdrant service's /search/file endpoint.
    # The whole upload is read into memory before it is forwarded.
    # `limit` is always sent; `score_threshold` and `collection` only when
    # the caller supplied them.
    image_bytes = await file.read()
    form_fields: Dict[str, Any] = {"limit": int(limit)}
    if score_threshold is not None:
        form_fields["score_threshold"] = float(score_threshold)
    if collection is not None:
        form_fields["collection"] = collection
    target = f"{QDRANT_SVC_URL}/search/file"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_file(client, target, image_bytes, form_fields)
    return result
@app.post("/vectors/search/vector")
async def vectors_search_vector(payload: Dict[str, Any]):
    # Pass the JSON body through unchanged to the Qdrant service's
    # /search/vector endpoint and return its JSON reply.
    target = f"{QDRANT_SVC_URL}/search/vector"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, payload)
    return result
@app.post("/vectors/delete")
async def vectors_delete(payload: Dict[str, Any]):
    # Pass the JSON body through unchanged to the Qdrant service's /delete
    # endpoint and return its JSON reply.
    target = f"{QDRANT_SVC_URL}/delete"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, payload)
    return result
@app.get("/vectors/collections")
async def vectors_collections():
    # Fetch /collections from the Qdrant service and return its JSON reply.
    # Any upstream status >= 400 is reported to the caller as a 502.
    target = f"{QDRANT_SVC_URL}/collections"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        reply = await client.get(target)
        if reply.status_code < 400:
            return reply.json()
        raise HTTPException(status_code=502, detail=f"Upstream error: {reply.status_code}")
@app.post("/vectors/collections")
async def vectors_create_collection(payload: Dict[str, Any]):
    # Pass the JSON body through unchanged as a POST to the Qdrant service's
    # /collections endpoint and return its JSON reply.
    target = f"{QDRANT_SVC_URL}/collections"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, payload)
    return result
@app.get("/vectors/collections/{name}")
async def vectors_collection_info(name: str):
    # Fetch /collections/{name} from the Qdrant service and return its JSON
    # reply. Any upstream status >= 400 is reported to the caller as a 502.
    #
    # `name` is caller-controlled, so it is percent-encoded as a single path
    # segment before being placed in the upstream URL. Without this,
    # characters such as "?", "#" or "%" would change the request that is
    # actually sent upstream.
    target = f"{QDRANT_SVC_URL}/collections/{quote(name, safe='')}"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        r = await client.get(target)
        if r.status_code >= 400:
            raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}")
        return r.json()
@app.delete("/vectors/collections/{name}")
async def vectors_delete_collection(name: str):
    # Send DELETE /collections/{name} to the Qdrant service and return its
    # JSON reply. Any upstream status >= 400 is reported as a 502.
    #
    # `name` is caller-controlled, so it is percent-encoded as a single path
    # segment before being placed in the upstream URL. Without this,
    # characters such as "?", "#" or "%" would change the request that is
    # actually sent upstream.
    target = f"{QDRANT_SVC_URL}/collections/{quote(name, safe='')}"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        r = await client.delete(target)
        if r.status_code >= 400:
            raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}")
        return r.json()
@app.get("/vectors/points/{point_id}")
async def vectors_get_point(point_id: str, collection: Optional[str] = None):
    # Fetch /points/{point_id} from the Qdrant service and return its JSON
    # reply. `collection` is forwarded as a query parameter only when it is
    # non-empty. Any upstream status >= 400 is reported as a 502.
    #
    # `point_id` is caller-controlled, so it is percent-encoded as a single
    # path segment before being placed in the upstream URL. Without this,
    # characters such as "?", "#" or "%" would change the request that is
    # actually sent upstream.
    params = {"collection": collection} if collection else {}
    target = f"{QDRANT_SVC_URL}/points/{quote(point_id, safe='')}"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        r = await client.get(target, params=params)
        if r.status_code >= 400:
            raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}")
        return r.json()
# ---- File-based universal analyze ----
@app.post("/analyze/all/file")
async def analyze_all_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    variants: int = Form(3),
    conf: float = Form(0.25),
    max_length: int = Form(60),
    threshold: Optional[float] = Form(None),
):
    # Run CLIP, BLIP and YOLO on the same uploaded file concurrently and
    # return the three upstream replies under "clip", "blip" and "yolo".
    # The whole upload is read into memory once and sent to each service.
    #
    # `threshold` is an optional CLIP setting. It is forwarded only when
    # supplied, matching /analyze/clip/file, so requests that omit it send
    # exactly the same fields as before.
    data = await file.read()
    clip_fields: Dict[str, Any] = {"limit": limit}
    if threshold is not None:
        clip_fields["threshold"] = threshold
    blip_fields = {"variants": variants, "max_length": max_length}
    yolo_fields = {"conf": conf}
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        clip_res, blip_res, yolo_res = await asyncio.gather(
            _post_file(client, f"{CLIP_URL}/analyze/file", data, clip_fields),
            _post_file(client, f"{BLIP_URL}/caption/file", data, blip_fields),
            _post_file(client, f"{YOLO_URL}/detect/file", data, yolo_fields),
        )
    return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res}