319 lines
12 KiB
Python
319 lines
12 KiB
Python
from __future__ import annotations

import asyncio
import os
import secrets
from typing import Any, Dict, Optional
from urllib.parse import quote

import httpx
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel, Field
|
|
|
|
# Base URLs of the upstream services. A trailing "/" is stripped because every
# call site appends its own "/path"; without this, "http://clip:8000/" would
# produce request URLs such as "http://clip:8000//analyze".
CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000").rstrip("/")
BLIP_URL = os.getenv("BLIP_URL", "http://blip:8000").rstrip("/")
YOLO_URL = os.getenv("YOLO_URL", "http://yolo:8000").rstrip("/")
QDRANT_SVC_URL = os.getenv("QDRANT_SVC_URL", "http://qdrant-svc:8000").rstrip("/")

# Timeout (seconds) passed to the httpx clients that call the upstream services.
VISION_TIMEOUT = float(os.getenv("VISION_TIMEOUT", "20"))

# API key (set via env var `API_KEY`). If it is unset or empty, the gateway
# rejects every request except the unauthenticated health/docs paths.
API_KEY = os.getenv("API_KEY")
|
|
|
|
class APIKeyMiddleware(BaseHTTPMiddleware):
    """Reject requests that do not carry the configured API key.

    The key is read from the ``X-API-Key`` request header and compared with
    the module-level ``API_KEY``. Paths in ``_PUBLIC_PATHS`` (health check and
    the OpenAPI/docs pages) are served without a key. When ``API_KEY`` is unset
    or empty, every other request is rejected with HTTP 401.
    """

    # Paths that are reachable without an API key.
    _PUBLIC_PATHS = frozenset({"/health", "/openapi.json", "/docs", "/redoc"})

    async def dispatch(self, request: Request, call_next):
        if request.url.path in self._PUBLIC_PATHS:
            return await call_next(request)

        # Starlette header lookup is case-insensitive, so a single lookup covers
        # "x-api-key", "X-API-Key", etc.
        key = request.headers.get("x-api-key")
        # compare_digest runs in constant time so the key cannot be recovered
        # byte-by-byte via response timing; bytes avoid its ASCII-only str rule.
        if (
            not API_KEY
            or key is None
            or not secrets.compare_digest(key.encode("utf-8"), API_KEY.encode("utf-8"))
        ):
            return JSONResponse(status_code=401, content={"detail": "Unauthorized"})

        return await call_next(request)
|
|
|
|
app = FastAPI(
    title="Skinbase Vision Gateway",
    version="1.0.0",
)
# Every request passes through the API-key check before reaching a route.
app.add_middleware(APIKeyMiddleware)
|
|
|
|
|
|
class ClipRequest(BaseModel):
    """JSON request body for ``POST /analyze/clip``."""

    url: Optional[str] = Field(default=None)  # image URL; the endpoint rejects a missing/empty value
    limit: int = Field(5, ge=1, le=50)  # accepted range: 1..50
    threshold: Optional[float] = Field(None, ge=0.0, le=1.0)  # optional, within [0.0, 1.0]
|
|
|
|
|
|
class BlipRequest(BaseModel):
    """JSON request body for ``POST /analyze/blip``."""

    url: Optional[str] = Field(default=None)  # image URL; the endpoint rejects a missing/empty value
    variants: int = Field(3, ge=0, le=10)  # accepted range: 0..10
    max_length: int = Field(60, ge=10, le=200)  # accepted range: 10..200
|
|
|
|
|
|
class YoloRequest(BaseModel):
    """JSON request body for ``POST /analyze/yolo``."""

    url: Optional[str] = Field(default=None)  # image URL; the endpoint rejects a missing/empty value
    conf: float = Field(0.25, ge=0.0, le=1.0)  # within [0.0, 1.0]
|
|
|
|
|
|
async def _get_health(client: httpx.AsyncClient, base: str) -> Dict[str, Any]:
|
|
try:
|
|
r = await client.get(f"{base}/health")
|
|
return r.json() if r.status_code == 200 else {"status": "bad", "code": r.status_code}
|
|
except Exception:
|
|
return {"status": "unreachable"}
|
|
|
|
|
|
async def _post_json(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
|
try:
|
|
r = await client.post(url, json=payload)
|
|
except httpx.RequestError as e:
|
|
raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}")
|
|
if r.status_code >= 400:
|
|
raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:1000]}")
|
|
try:
|
|
return r.json()
|
|
except Exception:
|
|
# upstream returned non-JSON (HTML error page or empty body)
|
|
raise HTTPException(status_code=502, detail=f"Upstream returned non-JSON at {url}: {r.status_code} {r.text[:1000]}")
|
|
|
|
|
|
async def _post_file(client: httpx.AsyncClient, url: str, data: bytes, fields: Dict[str, Any]) -> Dict[str, Any]:
|
|
files = {"file": ("image", data, "application/octet-stream")}
|
|
try:
|
|
r = await client.post(url, data={k: str(v) for k, v in fields.items()}, files=files)
|
|
except httpx.RequestError as e:
|
|
raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}")
|
|
if r.status_code >= 400:
|
|
raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:1000]}")
|
|
try:
|
|
return r.json()
|
|
except Exception:
|
|
raise HTTPException(status_code=502, detail=f"Upstream returned non-JSON at {url}: {r.status_code} {r.text[:1000]}")
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Report gateway liveness together with each upstream service's health probe.

    All four probes run concurrently on one client with a 5-second timeout; the
    per-service result comes from ``_get_health`` and the top-level status is
    always ``"ok"``.
    """
    targets = {
        "clip": CLIP_URL,
        "blip": BLIP_URL,
        "yolo": YOLO_URL,
        "qdrant": QDRANT_SVC_URL,
    }
    async with httpx.AsyncClient(timeout=5) as client:
        probes = await asyncio.gather(*(_get_health(client, base) for base in targets.values()))
    return {"status": "ok", "services": dict(zip(targets, probes))}
|
|
|
|
|
|
# ---- Individual analyze endpoints (URL) ----
|
|
|
|
@app.post("/analyze/clip")
async def analyze_clip(req: ClipRequest):
    """Forward a URL-based request to the CLIP service's ``/analyze`` endpoint.

    Raises:
        HTTPException: 400 when ``url`` is missing or empty; upstream failures
            are reported by ``_post_json``.
    """
    if req.url:
        endpoint = f"{CLIP_URL}/analyze"
        async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
            return await _post_json(client, endpoint, req.model_dump())
    raise HTTPException(status_code=400, detail="url is required")
|
|
|
|
|
|
@app.post("/analyze/blip")
async def analyze_blip(req: BlipRequest):
    """Forward a URL-based request to the BLIP service's ``/caption`` endpoint.

    Raises:
        HTTPException: 400 when ``url`` is missing or empty; upstream failures
            are reported by ``_post_json``.
    """
    if req.url:
        endpoint = f"{BLIP_URL}/caption"
        async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
            return await _post_json(client, endpoint, req.model_dump())
    raise HTTPException(status_code=400, detail="url is required")
|
|
|
|
|
|
@app.post("/analyze/yolo")
async def analyze_yolo(req: YoloRequest):
    """Forward a URL-based request to the YOLO service's ``/detect`` endpoint.

    Raises:
        HTTPException: 400 when ``url`` is missing or empty; upstream failures
            are reported by ``_post_json``.
    """
    if req.url:
        endpoint = f"{YOLO_URL}/detect"
        async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
            return await _post_json(client, endpoint, req.model_dump())
    raise HTTPException(status_code=400, detail="url is required")
|
|
|
|
|
|
# ---- Individual analyze endpoints (file upload) ----
|
|
|
|
|
|
@app.post("/analyze/clip/file")
async def analyze_clip_file(
    file: UploadFile = File(...),
    limit: int = Form(5, ge=1, le=50),
    threshold: Optional[float] = Form(None, ge=0.0, le=1.0),
):
    """Forward an uploaded image to the CLIP service's ``/analyze/file`` endpoint.

    ``limit`` and ``threshold`` carry the same bounds as ``ClipRequest`` so the
    file and URL variants accept the same options; out-of-range values are
    rejected by FastAPI (422) instead of being forwarded upstream.
    ``threshold`` is only sent when supplied.
    """
    data = await file.read()
    # FastAPI has already coerced the form values to int/float.
    fields: Dict[str, Any] = {"limit": limit}
    if threshold is not None:
        fields["threshold"] = threshold
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _post_file(client, f"{CLIP_URL}/analyze/file", data, fields)
|
|
|
|
|
|
@app.post("/analyze/blip/file")
async def analyze_blip_file(
    file: UploadFile = File(...),
    variants: int = Form(3, ge=0, le=10),
    max_length: int = Form(60, ge=10, le=200),
):
    """Forward an uploaded image to the BLIP service's ``/caption/file`` endpoint.

    ``variants`` and ``max_length`` carry the same bounds as ``BlipRequest`` so
    the file and URL variants accept the same options; out-of-range values are
    rejected by FastAPI (422) instead of being forwarded upstream.
    """
    data = await file.read()
    # FastAPI has already coerced the form values to int.
    fields = {"variants": variants, "max_length": max_length}
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _post_file(client, f"{BLIP_URL}/caption/file", data, fields)
|
|
|
|
|
|
@app.post("/analyze/yolo/file")
async def analyze_yolo_file(
    file: UploadFile = File(...),
    conf: float = Form(0.25, ge=0.0, le=1.0),
):
    """Forward an uploaded image to the YOLO service's ``/detect/file`` endpoint.

    ``conf`` carries the same bounds as ``YoloRequest`` so the file and URL
    variants accept the same options; out-of-range values are rejected by
    FastAPI (422) instead of being forwarded upstream.
    """
    data = await file.read()
    # FastAPI has already coerced the form value to float.
    fields = {"conf": conf}
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _post_file(client, f"{YOLO_URL}/detect/file", data, fields)
|
|
|
|
|
|
@app.post("/analyze/all")
async def analyze_all(payload: Dict[str, Any]):
    """Run CLIP, BLIP and YOLO concurrently on the image at ``payload["url"]``.

    Recognised keys besides ``url``: ``limit``/``threshold`` (CLIP),
    ``variants``/``max_length`` (BLIP) and ``conf`` (YOLO), with the same
    defaults and bounds as ``ClipRequest``/``BlipRequest``/``YoloRequest``.

    Returns:
        ``{"clip": ..., "blip": ..., "yolo": ...}``, each holding that
        service's decoded JSON response.

    Raises:
        HTTPException: 400 if ``url`` is missing or empty; 422 if an option has
            the wrong type or is out of range; 502 (raised by ``_post_json``)
            if any upstream call fails.
    """
    url = payload.get("url")
    if not url:
        raise HTTPException(400, "url is required")

    # Validate through the same models as the single-service endpoints, so bad
    # input (e.g. "limit": "abc" or null) yields a client error instead of an
    # unhandled int()/float() exception (HTTP 500).
    try:
        clip_req = ClipRequest(
            url=url, limit=payload.get("limit", 5), threshold=payload.get("threshold")
        ).model_dump()
        blip_req = BlipRequest(
            url=url, variants=payload.get("variants", 3), max_length=payload.get("max_length", 60)
        ).model_dump()
        yolo_req = YoloRequest(url=url, conf=payload.get("conf", 0.25)).model_dump()
    except ValueError as e:  # pydantic's ValidationError subclasses ValueError
        raise HTTPException(status_code=422, detail=str(e)) from e

    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        # return_exceptions=True makes gather wait for all three calls before the
        # shared client is closed; with the default, the first failure would
        # propagate while the other requests keep running on a closed client.
        results = await asyncio.gather(
            _post_json(client, f"{CLIP_URL}/analyze", clip_req),
            _post_json(client, f"{BLIP_URL}/caption", blip_req),
            _post_json(client, f"{YOLO_URL}/detect", yolo_req),
            return_exceptions=True,
        )

    for res in results:
        if isinstance(res, BaseException):
            raise res
    clip_res, blip_res, yolo_res = results
    return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res}
|
|
|
|
|
|
# ---- Vector / Qdrant endpoints ----
|
|
|
|
@app.post("/vectors/upsert")
async def vectors_upsert(payload: Dict[str, Any]):
    """Proxy the JSON body unchanged to the Qdrant service's ``/upsert`` endpoint."""
    target = f"{QDRANT_SVC_URL}/upsert"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, payload)
    return result
|
|
|
|
|
|
@app.post("/vectors/upsert/file")
async def vectors_upsert_file(
    file: UploadFile = File(...),
    # `id` shadows the builtin but is kept because it names the public form field.
    id: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
    metadata_json: Optional[str] = Form(None),
):
    """Forward an uploaded image to the Qdrant service's ``/upsert/file`` endpoint.

    Only the optional form fields that were actually supplied are forwarded.
    """
    image_bytes = await file.read()
    candidates = (
        ("id", id),
        ("collection", collection),
        ("metadata_json", metadata_json),
    )
    fields: Dict[str, Any] = {name: value for name, value in candidates if value is not None}
    target = f"{QDRANT_SVC_URL}/upsert/file"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_file(client, target, image_bytes, fields)
    return result
|
|
|
|
|
|
@app.post("/vectors/upsert/vector")
async def vectors_upsert_vector(payload: Dict[str, Any]):
    """Proxy the JSON body unchanged to the Qdrant service's ``/upsert/vector`` endpoint."""
    target = f"{QDRANT_SVC_URL}/upsert/vector"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, payload)
    return result
|
|
|
|
|
|
@app.post("/vectors/search")
async def vectors_search(payload: Dict[str, Any]):
    """Proxy the JSON body unchanged to the Qdrant service's ``/search`` endpoint."""
    target = f"{QDRANT_SVC_URL}/search"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, payload)
    return result
|
|
|
|
|
|
@app.post("/vectors/search/file")
async def vectors_search_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    score_threshold: Optional[float] = Form(None),
    collection: Optional[str] = Form(None),
):
    """Forward an uploaded image and search options to the Qdrant service's ``/search/file`` endpoint.

    ``limit`` is always sent; ``score_threshold`` and ``collection`` only when supplied.
    """
    image_bytes = await file.read()
    form_fields: Dict[str, Any] = {"limit": int(limit)}
    if score_threshold is not None:
        form_fields.update(score_threshold=float(score_threshold))
    if collection is not None:
        form_fields.update(collection=collection)
    target = f"{QDRANT_SVC_URL}/search/file"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_file(client, target, image_bytes, form_fields)
    return result
|
|
|
|
|
|
@app.post("/vectors/search/vector")
async def vectors_search_vector(payload: Dict[str, Any]):
    """Proxy the JSON body unchanged to the Qdrant service's ``/search/vector`` endpoint."""
    target = f"{QDRANT_SVC_URL}/search/vector"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, payload)
    return result
|
|
|
|
|
|
@app.post("/vectors/delete")
async def vectors_delete(payload: Dict[str, Any]):
    """Proxy the JSON body unchanged to the Qdrant service's ``/delete`` endpoint."""
    target = f"{QDRANT_SVC_URL}/delete"
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        result = await _post_json(client, target, payload)
    return result
|
|
|
|
|
|
def _path_segment(value: str) -> str:
    """Percent-encode *value* so it is used as exactly one upstream URL path segment.

    Without escaping, a value such as ``"x?collection=y"`` would inject a query
    string into the internal request.

    Raises:
        HTTPException: 400 for empty or dot segments ("." / ".."), which URL
            normalisation would collapse into a different upstream path.
    """
    if value in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="invalid path parameter")
    return quote(value, safe="")


async def _qdrant_json(client: httpx.AsyncClient, method: str, url: str, **kwargs: Any) -> Any:
    """Send a *method* request to *url* and return the decoded JSON body.

    Extra keyword arguments are passed to ``client.request`` (e.g. ``params``).

    Raises:
        HTTPException: 502 when the request fails at the transport level, the
            upstream answers with a status >= 400, or the body is not JSON.
    """
    try:
        r = await client.request(method, url, **kwargs)
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}") from e
    if r.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:1000]}")
    try:
        return r.json()
    except ValueError as e:
        raise HTTPException(
            status_code=502,
            detail=f"Upstream returned non-JSON at {url}: {r.status_code} {r.text[:1000]}",
        ) from e


@app.get("/vectors/collections")
async def vectors_collections():
    """Proxy ``GET /collections`` on the Qdrant service."""
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _qdrant_json(client, "GET", f"{QDRANT_SVC_URL}/collections")


@app.post("/vectors/collections")
async def vectors_create_collection(payload: Dict[str, Any]):
    """Proxy the JSON body unchanged to ``POST /collections`` on the Qdrant service."""
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _post_json(client, f"{QDRANT_SVC_URL}/collections", payload)


@app.get("/vectors/collections/{name}")
async def vectors_collection_info(name: str):
    """Proxy ``GET /collections/{name}`` on the Qdrant service."""
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _qdrant_json(client, "GET", f"{QDRANT_SVC_URL}/collections/{_path_segment(name)}")


@app.delete("/vectors/collections/{name}")
async def vectors_delete_collection(name: str):
    """Proxy ``DELETE /collections/{name}`` on the Qdrant service."""
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _qdrant_json(client, "DELETE", f"{QDRANT_SVC_URL}/collections/{_path_segment(name)}")


@app.get("/vectors/points/{point_id}")
async def vectors_get_point(point_id: str, collection: Optional[str] = None):
    """Proxy ``GET /points/{point_id}`` on the Qdrant service.

    ``collection`` is forwarded as a query parameter only when non-empty.
    """
    params = {"collection": collection} if collection else {}
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _qdrant_json(
            client, "GET", f"{QDRANT_SVC_URL}/points/{_path_segment(point_id)}", params=params
        )
|
|
|
|
|
|
# ---- File-based universal analyze ----
|
|
|
|
@app.post("/analyze/all/file")
async def analyze_all_file(
    file: UploadFile = File(...),
    limit: int = Form(5, ge=1, le=50),
    variants: int = Form(3, ge=0, le=10),
    conf: float = Form(0.25, ge=0.0, le=1.0),
    max_length: int = Form(60, ge=10, le=200),
    threshold: Optional[float] = Form(None, ge=0.0, le=1.0),
):
    """Run CLIP, BLIP and YOLO concurrently on one uploaded image.

    Options use the same defaults and bounds as the single-service request
    models (out-of-range values are rejected by FastAPI with 422). ``threshold``
    is optional and only forwarded to CLIP when supplied, matching
    ``/analyze/all``.

    Returns:
        ``{"clip": ..., "blip": ..., "yolo": ...}``, each holding that
        service's decoded JSON response.

    Raises:
        HTTPException: 502 (raised by ``_post_file``) if any upstream call fails.
    """
    data = await file.read()
    clip_fields: Dict[str, Any] = {"limit": limit}
    if threshold is not None:
        clip_fields["threshold"] = threshold

    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        # return_exceptions=True makes gather wait for all three uploads before
        # the shared client is closed; with the default, the first failure would
        # propagate while the other requests keep running on a closed client.
        results = await asyncio.gather(
            _post_file(client, f"{CLIP_URL}/analyze/file", data, clip_fields),
            _post_file(client, f"{BLIP_URL}/caption/file", data, {"variants": variants, "max_length": max_length}),
            _post_file(client, f"{YOLO_URL}/detect/file", data, {"conf": conf}),
            return_exceptions=True,
        )

    for res in results:
        if isinstance(res, BaseException):
            raise res
    clip_res, blip_res, yolo_res = results
    return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res}
|