# File: vision/qdrant/main.py
# 371 lines, 12 KiB, Python

from __future__ import annotations
import os
import uuid
from typing import Any, Dict, List, Optional
import httpx
import numpy as np
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
PointStruct,
VectorParams,
Filter,
FieldCondition,
MatchValue,
)
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Settings are read from the environment once, at import time; the second
# argument of each lookup is the value used when the variable is unset.
QDRANT_HOST = os.environ.get("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.environ.get("QDRANT_PORT", "6333"))
CLIP_URL = os.environ.get("CLIP_URL", "http://clip:8000")
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "images")
VECTOR_DIM = int(os.environ.get("VECTOR_DIM", "512"))

app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0")

# Placeholder only: the startup hook rebinds this to a real QdrantClient.
client: QdrantClient = None  # type: ignore[assignment]
# ---------------------------------------------------------------------------
# Startup / shutdown
# ---------------------------------------------------------------------------
@app.on_event("startup")
def startup():
    """Open the module-wide Qdrant connection and ensure the default collection."""
    global client
    client = QdrantClient(port=QDRANT_PORT, host=QDRANT_HOST)
    _ensure_collection()
def _ensure_collection():
    """Create the default collection unless Qdrant already lists it.

    The collection is created with ``VECTOR_DIM`` dimensions and cosine
    distance.
    """
    existing = client.get_collections().collections
    if any(entry.name == COLLECTION_NAME for entry in existing):
        return
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(distance=Distance.COSINE, size=VECTOR_DIM),
    )
# ---------------------------------------------------------------------------
# Request / Response models
# ---------------------------------------------------------------------------
class UpsertUrlRequest(BaseModel):
    """Body of ``POST /upsert``: an image URL to embed and store."""

    # Image location forwarded to the CLIP service.
    url: str
    # Optional caller-chosen point id; a random one is generated when omitted.
    id: Optional[str] = Field(default=None)
    # Free-form payload stored next to the vector.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Target collection; the default collection is used when omitted.
    collection: Optional[str] = Field(default=None)
class UpsertVectorRequest(BaseModel):
    """Body of ``POST /upsert/vector``: a pre-computed vector to store."""

    # Embedding to store as-is; no CLIP call is made for it.
    vector: List[float]
    # Optional caller-chosen point id; a random one is generated when omitted.
    id: Optional[str] = Field(default=None)
    # Free-form payload stored next to the vector.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Target collection; the default collection is used when omitted.
    collection: Optional[str] = Field(default=None)
class SearchUrlRequest(BaseModel):
    """Body of ``POST /search``: find neighbours of the image at a URL."""

    # Image location forwarded to the CLIP service.
    url: str
    # Maximum number of hits to return (1 to 100).
    limit: int = Field(5, ge=1, le=100)
    # Optional minimum score, constrained to the range 0.0 to 1.0.
    score_threshold: Optional[float] = Field(None, ge=0.0, le=1.0)
    # Collection to search; the default collection is used when omitted.
    collection: Optional[str] = Field(default=None)
    # Exact-match payload conditions; every entry must match.
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)
class SearchVectorRequest(BaseModel):
    """Body of ``POST /search/vector``: find neighbours of a given vector."""

    # Query embedding used directly; no CLIP call is made for it.
    vector: List[float]
    # Maximum number of hits to return (1 to 100).
    limit: int = Field(5, ge=1, le=100)
    # Optional minimum score, constrained to the range 0.0 to 1.0.
    score_threshold: Optional[float] = Field(None, ge=0.0, le=1.0)
    # Collection to search; the default collection is used when omitted.
    collection: Optional[str] = Field(default=None)
    # Exact-match payload conditions; every entry must match.
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)
class DeleteRequest(BaseModel):
    """Body of ``POST /delete``: point ids to remove from a collection."""

    # Ids of the points to delete.
    ids: List[str]
    # Collection to delete from; the default collection is used when omitted.
    collection: Optional[str] = Field(default=None)
class CollectionRequest(BaseModel):
    """Body of ``POST /collections``: parameters of a new collection."""

    # Name of the collection to create.
    name: str
    # Vector size of the collection; must be at least 1.
    vector_dim: int = Field(512, ge=1)
    # Distance metric name; the endpoint accepts cosine, euclid or dot.
    distance: str = "cosine"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _col(name: Optional[str]) -> str:
    """Return *name*, or the default collection when *name* is empty or None."""
    if name:
        return name
    return COLLECTION_NAME
async def _embed_url(url: str) -> List[float]:
    """Ask the CLIP service for the embedding of the image at *url*.

    Sends ``{"url": url}`` as JSON to ``{CLIP_URL}/embed`` with a 30 second
    timeout and returns the ``"vector"`` entry of the JSON reply.

    Raises:
        HTTPException: 502 when the request fails, CLIP answers with a
            status of 400 or above, the reply cannot be parsed, or the
            reply carries no usable vector.
    """
    async with httpx.AsyncClient(timeout=30) as http:
        try:
            r = await http.post(f"{CLIP_URL}/embed", json={"url": url})
        except httpx.RequestError as e:
            # Chain the cause so the transport error stays in the traceback.
            raise HTTPException(502, f"CLIP request failed: {str(e)}") from e

    if r.status_code >= 400:
        raise HTTPException(502, f"CLIP /embed error: {r.status_code} {r.text[:200]}")

    try:
        vector = r.json()["vector"]
    except Exception as e:
        raise HTTPException(
            502, f"CLIP /embed returned non-JSON: {r.status_code} {r.text[:200]}"
        ) from e

    # A null, empty or non-list "vector" would otherwise surface later as an
    # unrelated server error in the caller; report it as a CLIP failure here.
    if not isinstance(vector, list) or not vector:
        raise HTTPException(
            502, f"CLIP /embed returned no usable vector: {r.status_code} {r.text[:200]}"
        )
    return vector
async def _embed_bytes(data: bytes) -> List[float]:
    """Ask the CLIP service for the embedding of raw image bytes.

    Uploads *data* as the multipart field ``file`` to
    ``{CLIP_URL}/embed/file`` with a 30 second timeout and returns the
    ``"vector"`` entry of the JSON reply.

    Raises:
        HTTPException: 502 when the request fails, CLIP answers with a
            status of 400 or above, the reply cannot be parsed, or the
            reply carries no usable vector.
    """
    files = {"file": ("image", data, "application/octet-stream")}
    async with httpx.AsyncClient(timeout=30) as http:
        try:
            r = await http.post(f"{CLIP_URL}/embed/file", files=files)
        except httpx.RequestError as e:
            # Chain the cause so the transport error stays in the traceback.
            raise HTTPException(502, f"CLIP request failed: {str(e)}") from e

    if r.status_code >= 400:
        raise HTTPException(502, f"CLIP /embed/file error: {r.status_code} {r.text[:200]}")

    try:
        vector = r.json()["vector"]
    except Exception as e:
        raise HTTPException(
            502, f"CLIP /embed/file returned non-JSON: {r.status_code} {r.text[:200]}"
        ) from e

    # A null, empty or non-list "vector" would otherwise surface later as an
    # unrelated server error in the caller; report it as a CLIP failure here.
    if not isinstance(vector, list) or not vector:
        raise HTTPException(
            502, f"CLIP /embed/file returned no usable vector: {r.status_code} {r.text[:200]}"
        )
    return vector
def _build_filter(metadata: Dict[str, Any]) -> Optional[Filter]:
if not metadata:
return None
conditions = [
FieldCondition(key=k, match=MatchValue(value=v))
for k, v in metadata.items()
]
return Filter(must=conditions)
def _point_id(raw: Optional[str]) -> str:
return raw or uuid.uuid4().hex
# ---------------------------------------------------------------------------
# Health
# ---------------------------------------------------------------------------
@app.get("/health")
def health():
    """Report whether Qdrant answers and which collections it lists.

    A failure is described in the response body instead of being raised.
    """
    try:
        listing = client.get_collections()
        return {
            "status": "ok",
            "qdrant": QDRANT_HOST,
            "collections": [entry.name for entry in listing.collections],
        }
    except Exception as exc:
        return {"status": "error", "detail": str(exc)}
# ---------------------------------------------------------------------------
# Collection management
# ---------------------------------------------------------------------------
@app.post("/collections")
def create_collection(req: CollectionRequest):
    """Create a collection with the requested name, vector size and metric.

    Raises:
        HTTPException: 400 for an unrecognised distance name, 409 when a
            collection with that name is already listed.
    """
    metric = req.distance.lower()
    if metric == "cosine":
        dist = Distance.COSINE
    elif metric == "euclid":
        dist = Distance.EUCLID
    elif metric == "dot":
        dist = Distance.DOT
    else:
        raise HTTPException(400, f"Unknown distance: {req.distance}. Use cosine, euclid, or dot.")

    existing = {entry.name for entry in client.get_collections().collections}
    if req.name in existing:
        raise HTTPException(409, f"Collection '{req.name}' already exists")

    params = VectorParams(size=req.vector_dim, distance=dist)
    client.create_collection(collection_name=req.name, vectors_config=params)
    # The distance is echoed exactly as the caller spelled it.
    return {"created": req.name, "vector_dim": req.vector_dim, "distance": req.distance}
@app.get("/collections")
def list_collections():
    """Return the names of every collection Qdrant reports."""
    names = [entry.name for entry in client.get_collections().collections]
    return {"collections": names}
@app.get("/collections/{name}")
def collection_info(name: str):
    """Return the vector count, point count and status of collection *name*.

    Raises:
        HTTPException: 404 when Qdrant cannot return the collection.
    """
    # Guard only the lookup, so a problem while building the response is not
    # misreported to the caller as "collection not found".
    try:
        info = client.get_collection(name)
    except Exception as e:
        raise HTTPException(404, str(e)) from e

    return {
        "name": name,
        # Read defensively. NOTE(review): this field is believed to be
        # deprecated and absent from some qdrant-client releases; confirm.
        "vectors_count": getattr(info, "vectors_count", None),
        "points_count": info.points_count,
        "status": info.status.value if info.status else None,
    }
@app.delete("/collections/{name}")
def delete_collection(name: str):
    """Ask Qdrant to drop collection *name* and echo the name back."""
    client.delete_collection(collection_name=name)
    return {"deleted": name}
# ---------------------------------------------------------------------------
# Upsert endpoints
# ---------------------------------------------------------------------------
@app.post("/upsert")
async def upsert_url(req: UpsertUrlRequest):
    """Embed the image at ``req.url`` through CLIP and store it in Qdrant.

    Returns the point id, the collection written to and the vector length.
    """
    embedding = await _embed_url(req.url)
    point_id = _point_id(req.id)
    target = _col(req.collection)

    # "source_url" is assigned last so it wins over a same-named metadata key.
    payload = dict(req.metadata)
    payload["source_url"] = req.url

    point = PointStruct(id=point_id, vector=embedding, payload=payload)
    client.upsert(collection_name=target, points=[point])
    return {"id": point_id, "collection": target, "dim": len(embedding)}
@app.post("/upsert/file")
async def upsert_file(
    file: UploadFile = File(...),
    id: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
    metadata_json: Optional[str] = Form(None),
):
    """Embed an uploaded image through CLIP and store it in Qdrant.

    Form fields:
        file: the image to embed.
        id: optional point id; a random one is generated when omitted.
        collection: optional target collection; defaults to the configured one.
        metadata_json: optional JSON object stored as the point payload.

    Returns the point id, the collection written to and the vector length.

    Raises:
        HTTPException: 400 when ``metadata_json`` is not valid JSON or is
            not a JSON object.
    """
    import json

    # Validate the cheap, caller-supplied metadata before the CLIP round-trip,
    # so a malformed request is rejected without doing any embedding work.
    payload: Dict[str, Any] = {}
    if metadata_json:
        try:
            payload = json.loads(metadata_json)
        except json.JSONDecodeError as e:
            raise HTTPException(400, "metadata_json must be valid JSON") from e
        # Lists, strings and numbers are valid JSON but cannot be a payload.
        if not isinstance(payload, dict):
            raise HTTPException(400, "metadata_json must be a JSON object")

    data = await file.read()
    vector = await _embed_bytes(data)
    pid = _point_id(id)
    col = _col(collection)
    client.upsert(
        collection_name=col,
        points=[PointStruct(id=pid, vector=vector, payload=payload)],
    )
    return {"id": pid, "collection": col, "dim": len(vector)}
@app.post("/upsert/vector")
def upsert_vector(req: UpsertVectorRequest):
    """Store a caller-supplied vector in Qdrant without calling CLIP.

    Returns the point id, the collection written to and the vector length.
    """
    point_id = _point_id(req.id)
    target = _col(req.collection)
    point = PointStruct(id=point_id, vector=req.vector, payload=req.metadata)
    client.upsert(collection_name=target, points=[point])
    return {"id": point_id, "collection": target, "dim": len(req.vector)}
# ---------------------------------------------------------------------------
# Search endpoints
# ---------------------------------------------------------------------------
@app.post("/search")
async def search_url(req: SearchUrlRequest):
    """Embed the image at ``req.url`` through CLIP and return its nearest points."""
    embedding = await _embed_url(req.url)
    return _do_search(
        vector=embedding,
        limit=req.limit,
        score_threshold=req.score_threshold,
        collection=req.collection,
        filter_metadata=req.filter_metadata,
    )
@app.post("/search/file")
async def search_file(
    file: UploadFile = File(...),
    # Same 1..100 bound the JSON search request models place on ``limit``, so
    # zero, negative or oversized values are rejected before reaching Qdrant.
    limit: int = Form(5, ge=1, le=100),
    score_threshold: Optional[float] = Form(None),
    collection: Optional[str] = Form(None),
):
    """Embed an uploaded image through CLIP and return its nearest points.

    Form fields:
        file: the query image.
        limit: maximum number of hits (1 to 100, default 5).
        score_threshold: optional minimum score for a hit.
        collection: optional collection to search; defaults to the
            configured one.

    No metadata filter is applied by this endpoint.
    """
    data = await file.read()
    vector = await _embed_bytes(data)
    return _do_search(vector, limit, score_threshold, collection, {})
@app.post("/search/vector")
def search_vector(req: SearchVectorRequest):
    """Return the nearest points to a caller-supplied vector, without calling CLIP."""
    return _do_search(
        vector=req.vector,
        limit=req.limit,
        score_threshold=req.score_threshold,
        collection=req.collection,
        filter_metadata=req.filter_metadata,
    )
def _do_search(
    vector: List[float],
    limit: int,
    score_threshold: Optional[float],
    collection: Optional[str],
    filter_metadata: Dict[str, Any],
):
    """Query Qdrant with *vector* and shape the hits for the API response.

    Returns a dict holding the hits (id, score and payload of each point),
    the collection that was searched and the number of hits.
    """
    target = _col(collection)
    response = client.query_points(
        collection_name=target,
        query=vector,
        limit=limit,
        score_threshold=score_threshold,
        query_filter=_build_filter(filter_metadata),
    )
    matches = [
        {"id": hit.id, "score": hit.score, "metadata": hit.payload}
        for hit in response.points
    ]
    return {"results": matches, "collection": target, "count": len(matches)}
# ---------------------------------------------------------------------------
# Delete points
# ---------------------------------------------------------------------------
@app.post("/delete")
def delete_points(req: DeleteRequest):
    """Delete the points with the given ids and echo the ids back."""
    target = _col(req.collection)
    client.delete(points_selector=req.ids, collection_name=target)
    return {"deleted": req.ids, "collection": target}
# ---------------------------------------------------------------------------
# Get point by ID
# ---------------------------------------------------------------------------
@app.get("/points/{point_id}")
def get_point(point_id: str, collection: Optional[str] = None):
    """Fetch one point, including its vector and payload, by id.

    Raises:
        HTTPException: 404 when the lookup fails or returns no point.
    """
    target = _col(collection)
    try:
        found = client.retrieve(collection_name=target, ids=[point_id], with_vectors=True)
    except Exception as exc:
        # Any retrieval error is reported to the caller as "not found".
        raise HTTPException(404, str(exc))

    if not found:
        raise HTTPException(404, f"Point '{point_id}' not found")

    record = found[0]
    return {
        "id": record.id,
        "vector": record.vector,
        "metadata": record.payload,
        "collection": target,
    }