441 lines
14 KiB
Python
441 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import httpx
|
|
import numpy as np
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
|
from pydantic import BaseModel, Field
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.models import (
|
|
Distance,
|
|
PointStruct,
|
|
VectorParams,
|
|
Filter,
|
|
FieldCondition,
|
|
MatchValue,
|
|
)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Connection settings; each value may be overridden via an environment variable.
QDRANT_HOST = os.environ.get("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.environ.get("QDRANT_PORT", "6333"))

# Base URL of the CLIP embedding service (this module posts to /embed and /embed/file).
CLIP_URL = os.environ.get("CLIP_URL", "http://clip:8000")

# Collection used when a request names none, and the vector size that
# collection is created with at startup.
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "images")
VECTOR_DIM = int(os.environ.get("VECTOR_DIM", "512"))

app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0")

# Shared Qdrant client: None at import time, assigned by the startup hook.
client: QdrantClient = None  # type: ignore[assignment]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Startup / shutdown
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.on_event("startup")
def startup():
    """Connect to Qdrant and make sure the default collection is present.

    Rebinds the module-level ``client`` so every endpoint uses the same
    QdrantClient instance, then creates COLLECTION_NAME if it is missing.
    """
    global client
    qdrant = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
    client = qdrant
    _ensure_collection()
|
|
|
|
|
|
def _ensure_collection():
    """Create the default collection (COLLECTION_NAME) if it does not exist yet.

    The collection is created with VECTOR_DIM-sized vectors and cosine
    distance; an existing collection is left untouched.

    Several server processes (e.g. uvicorn ``--workers N``) may run this at
    the same moment. If one of them loses the race, its ``create_collection``
    call fails even though the collection now exists. That case is tolerated.
    The error is re-raised only if the collection is still missing afterwards.
    """
    existing = {c.name for c in client.get_collections().collections}
    if COLLECTION_NAME in existing:
        return
    try:
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=VectorParams(size=VECTOR_DIM, distance=Distance.COSINE),
        )
    except Exception:
        # Another process may have created it between our check and the create.
        existing = {c.name for c in client.get_collections().collections}
        if COLLECTION_NAME not in existing:
            raise
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Request / Response models
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class UpsertUrlRequest(BaseModel):
    # Body for POST /upsert: an image URL that is embedded via CLIP and stored.
    url: str  # forwarded to the CLIP /embed endpoint and kept as "source_url"
    id: Optional[str] = Field(default=None)  # caller id, mapped by _point_id
    metadata: Dict[str, Any] = Field(default_factory=dict)  # merged into the payload
    collection: Optional[str] = Field(default=None)  # None -> COLLECTION_NAME
|
|
|
|
|
|
class UpsertVectorRequest(BaseModel):
    # Body for POST /upsert/vector: a precomputed embedding stored without CLIP.
    vector: List[float]  # presumably must match the collection's vector size — Qdrant enforces it
    id: Optional[str] = Field(default=None)  # caller id, mapped by _point_id
    metadata: Dict[str, Any] = Field(default_factory=dict)  # copied into the payload
    collection: Optional[str] = Field(default=None)  # None -> COLLECTION_NAME
|
|
|
|
|
|
class SearchUrlRequest(BaseModel):
    # Body for POST /search: embed the image at ``url`` and return its neighbours.
    url: str
    limit: int = Field(default=5, ge=1, le=100)
    # Passed straight to Qdrant and deliberately not clamped to [0, 1]: its
    # meaning depends on the collection's distance metric (cosine similarity
    # lies in [-1, 1], dot-product scores are unbounded, Euclidean scores are
    # distances). Collections with any of these metrics can be created through
    # POST /collections.
    score_threshold: Optional[float] = None
    collection: Optional[str] = None  # None -> COLLECTION_NAME
    # Payload key/value pairs that every hit must match.
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class SearchVectorRequest(BaseModel):
    # Body for POST /search/vector: search with a caller-supplied query vector.
    vector: List[float]
    limit: int = Field(default=5, ge=1, le=100)
    # Passed straight to Qdrant and deliberately not clamped to [0, 1]: its
    # meaning depends on the collection's distance metric (cosine similarity
    # lies in [-1, 1], dot-product scores are unbounded, Euclidean scores are
    # distances). Collections with any of these metrics can be created through
    # POST /collections.
    score_threshold: Optional[float] = None
    collection: Optional[str] = None  # None -> COLLECTION_NAME
    # Payload key/value pairs that every hit must match.
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class DeleteRequest(BaseModel):
    # Body for POST /delete.
    ids: List[str] = Field(...)  # ids of the points to remove
    collection: Optional[str] = Field(default=None)  # None -> COLLECTION_NAME
|
|
|
|
|
|
class CollectionRequest(BaseModel):
    # Body for POST /collections.
    name: str
    # Defaults to the service-wide VECTOR_DIM (512 unless overridden by the
    # environment), so new collections match the size the default collection
    # is created with instead of a separately hard-coded number.
    vector_dim: int = Field(default=VECTOR_DIM, ge=1)
    # "cosine", "euclid" or "dot", case-insensitive; validated by the endpoint.
    distance: str = Field(default="cosine")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _col(name: Optional[str]) -> str:
    """Resolve a requested collection name, falling back to COLLECTION_NAME.

    Any falsy value (None or an empty string) selects the default collection.
    """
    return name if name else COLLECTION_NAME
|
|
|
|
|
|
async def _embed_url(url: str) -> List[float]:
    """Ask the CLIP service to embed the image located at ``url``.

    Args:
        url: Image URL, sent as ``{"url": url}`` to ``POST {CLIP_URL}/embed``.

    Returns:
        The ``"vector"`` field of CLIP's JSON response.

    Raises:
        HTTPException(502): CLIP is unreachable, answers with an HTTP error
            status, returns a body that is not JSON, or returns JSON without
            a ``"vector"`` key (each case gets its own message).
    """
    async with httpx.AsyncClient(timeout=30) as http:
        try:
            r = await http.post(f"{CLIP_URL}/embed", json={"url": url})
        except httpx.RequestError as e:
            raise HTTPException(502, f"CLIP request failed: {str(e)}") from e

        if r.status_code >= 400:
            raise HTTPException(502, f"CLIP /embed error: {r.status_code} {r.text[:200]}")

        try:
            body = r.json()
        except ValueError as e:  # json.JSONDecodeError / decode errors subclass ValueError
            raise HTTPException(
                502, f"CLIP /embed returned non-JSON: {r.status_code} {r.text[:200]}"
            ) from e
        try:
            return body["vector"]
        except (KeyError, TypeError) as e:  # valid JSON, but not {"vector": ...}
            raise HTTPException(
                502, f"CLIP /embed response missing 'vector': {r.status_code} {r.text[:200]}"
            ) from e
|
|
|
|
|
|
async def _embed_bytes(data: bytes) -> List[float]:
    """Ask the CLIP service to embed raw uploaded image bytes.

    Args:
        data: File contents, sent as multipart field ``file`` to
            ``POST {CLIP_URL}/embed/file``.

    Returns:
        The ``"vector"`` field of CLIP's JSON response.

    Raises:
        HTTPException(502): CLIP is unreachable, answers with an HTTP error
            status, returns a body that is not JSON, or returns JSON without
            a ``"vector"`` key (each case gets its own message).
    """
    async with httpx.AsyncClient(timeout=30) as http:
        files = {"file": ("image", data, "application/octet-stream")}
        try:
            r = await http.post(f"{CLIP_URL}/embed/file", files=files)
        except httpx.RequestError as e:
            raise HTTPException(502, f"CLIP request failed: {str(e)}") from e

        if r.status_code >= 400:
            raise HTTPException(502, f"CLIP /embed/file error: {r.status_code} {r.text[:200]}")

        try:
            body = r.json()
        except ValueError as e:  # json.JSONDecodeError / decode errors subclass ValueError
            raise HTTPException(
                502, f"CLIP /embed/file returned non-JSON: {r.status_code} {r.text[:200]}"
            ) from e
        try:
            return body["vector"]
        except (KeyError, TypeError) as e:  # valid JSON, but not {"vector": ...}
            raise HTTPException(
                502, f"CLIP /embed/file response missing 'vector': {r.status_code} {r.text[:200]}"
            ) from e
|
|
|
|
|
|
def _build_filter(metadata: Dict[str, Any]) -> Optional[Filter]:
    """Translate ``{key: value}`` pairs into an AND-ed Qdrant payload filter.

    * Scalar values become exact matches (``MatchValue``).
    * List/tuple/set values become "any of" matches (``MatchAny``). So
      ``{"tag": ["a", "b"]}`` selects points whose ``tag`` is ``"a"`` or
      ``"b"``. ``MatchValue`` itself only accepts str/int/bool and would
      reject a list.

    Returns None for an empty mapping so the search runs unfiltered.
    Unsupported value types (e.g. floats, nested objects) still raise the
    qdrant model's validation error.
    """
    if not metadata:
        return None
    from qdrant_client.models import MatchAny  # only needed for list values

    conditions = []
    for key, value in metadata.items():
        if isinstance(value, (list, tuple, set)):
            match = MatchAny(any=list(value))
        else:
            match = MatchValue(value=value)
        conditions.append(FieldCondition(key=key, match=match))
    return Filter(must=conditions)
|
|
|
|
|
|
def _id_filter(original_id: str) -> Filter:
    """Build a filter selecting points whose ``_original_id`` payload equals ``original_id``."""
    condition = FieldCondition(
        key="_original_id",
        match=MatchValue(value=original_id),
    )
    return Filter(must=[condition])
|
|
|
|
|
|
def _point_id(raw: Optional[str]) -> str:
|
|
"""Return a Qdrant-compatible point id.
|
|
|
|
Qdrant accepts either an unsigned integer or a UUID string (with hyphens).
|
|
If the provided `raw` value is an int or valid UUID we return it (as int or str).
|
|
Otherwise we generate a new UUID string and the caller should store the
|
|
original `raw` value in the point payload under `_original_id`.
|
|
"""
|
|
if not raw:
|
|
return str(uuid.uuid4())
|
|
# allow integer ids
|
|
try:
|
|
return int(raw)
|
|
except Exception:
|
|
pass
|
|
# allow UUID strings
|
|
try:
|
|
u = uuid.UUID(raw)
|
|
return str(u)
|
|
except Exception:
|
|
# fallback: generate a UUID
|
|
return str(uuid.uuid4())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Health
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.get("/health")
def health():
    """Report whether Qdrant is reachable, plus the collections it holds.

    On success, responds 200 with
    ``{"status": "ok", "qdrant": <host>, "collections": [...]}``.

    If Qdrant cannot be queried, responds 503 with
    ``{"status": "error", "detail": ...}``. Health probes that only look at
    the HTTP status code (e.g. Docker/Kubernetes checks) then see the failure.
    """
    from fastapi.responses import JSONResponse

    try:
        info = client.get_collections()
    except Exception as e:
        return JSONResponse(status_code=503, content={"status": "error", "detail": str(e)})
    names = [c.name for c in info.collections]
    return {"status": "ok", "qdrant": QDRANT_HOST, "collections": names}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Collection management
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.post("/collections")
def create_collection(req: CollectionRequest):
    """Create a new collection with the requested vector size and metric.

    Raises:
        HTTPException(400): ``req.distance`` is not cosine, euclid or dot.
        HTTPException(409): a collection named ``req.name`` already exists.
    """
    metrics = {
        "cosine": Distance.COSINE,
        "euclid": Distance.EUCLID,
        "dot": Distance.DOT,
    }
    metric_key = req.distance.lower()
    if metric_key not in metrics:
        raise HTTPException(400, f"Unknown distance: {req.distance}. Use cosine, euclid, or dot.")

    existing = {c.name for c in client.get_collections().collections}
    if req.name in existing:
        raise HTTPException(409, f"Collection '{req.name}' already exists")

    params = VectorParams(size=req.vector_dim, distance=metrics[metric_key])
    client.create_collection(collection_name=req.name, vectors_config=params)
    return {"created": req.name, "vector_dim": req.vector_dim, "distance": req.distance}
|
|
|
|
|
|
@app.get("/collections")
def list_collections():
    """Return the names of every collection Qdrant currently holds."""
    names = [collection.name for collection in client.get_collections().collections]
    return {"collections": names}
|
|
|
|
|
|
@app.get("/collections/{name}")
def collection_info(name: str):
    """Summarise one collection: vector/point counts and collection status.

    Raises:
        HTTPException(404): Qdrant could not return info for ``name``. The
            Qdrant error text becomes the detail.
    """
    try:
        info = client.get_collection(name)
    except Exception as e:
        raise HTTPException(404, str(e)) from e

    status = info.status
    return {
        "name": name,
        # ``vectors_count`` has been deprecated/removed in newer Qdrant
        # releases; read it defensively instead of letting an AttributeError
        # surface as a misleading 404.
        "vectors_count": getattr(info, "vectors_count", None),
        "points_count": info.points_count,
        "status": status.value if status else None,
    }
|
|
|
|
|
|
@app.delete("/collections/{name}")
def delete_collection(name: str):
    """Drop collection ``name`` from Qdrant and echo the name back."""
    client.delete_collection(collection_name=name)
    response = {"deleted": name}
    return response
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Upsert endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.post("/upsert")
async def upsert_url(req: UpsertUrlRequest):
    """Embed an image by URL via CLIP, then store the vector in Qdrant.

    The stored payload is ``req.metadata`` plus ``source_url``. It also gets
    ``_original_id`` when ``req.id`` could not be used verbatim as a point id.

    Returns:
        ``{"id": <point id>, "collection": <name>, "dim": <vector length>}``

    Raises:
        HTTPException(502): the CLIP call failed (raised by ``_embed_url``).
        HTTPException(500): Qdrant rejected the upsert.
    """
    from fastapi.concurrency import run_in_threadpool

    vector = await _embed_url(req.url)
    pid = _point_id(req.id)
    payload = {**req.metadata, "source_url": req.url}
    # preserve original user-provided id if it wasn't usable as a point id
    if req.id is not None and str(pid) != str(req.id):
        payload["_original_id"] = req.id
    col = _col(req.collection)

    try:
        # QdrantClient is synchronous; calling it directly inside this
        # coroutine would block the event loop for the whole round-trip.
        await run_in_threadpool(
            client.upsert,
            collection_name=col,
            points=[PointStruct(id=pid, vector=vector, payload=payload)],
        )
    except Exception as e:
        raise HTTPException(500, str(e)) from e

    return {"id": pid, "collection": col, "dim": len(vector)}
|
|
|
|
|
|
@app.post("/upsert/file")
async def upsert_file(
    file: UploadFile = File(...),
    id: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
    metadata_json: Optional[str] = Form(None),
):
    """Embed an uploaded image via CLIP, then store the vector in Qdrant.

    ``metadata_json``, when given, must be a JSON object; it becomes the
    point payload. It is validated before the image is sent to CLIP, so
    malformed metadata fails fast without an embedding round-trip.

    Returns:
        ``{"id": <point id>, "collection": <name>, "dim": <vector length>}``

    Raises:
        HTTPException(400): ``metadata_json`` is not valid JSON, or is valid
            JSON but not an object.
        HTTPException(502): the CLIP call failed (raised by ``_embed_bytes``).
        HTTPException(500): Qdrant rejected the upsert.
    """
    import json

    from fastapi.concurrency import run_in_threadpool

    payload: Dict[str, Any] = {}
    if metadata_json:
        try:
            payload = json.loads(metadata_json)
        except json.JSONDecodeError as e:
            raise HTTPException(400, "metadata_json must be valid JSON") from e
        # A JSON array/number/string is not a usable payload; previously it
        # surfaced as a 500 from the payload assignment or PointStruct.
        if not isinstance(payload, dict):
            raise HTTPException(400, "metadata_json must be a JSON object")

    data = await file.read()
    vector = await _embed_bytes(data)
    pid = _point_id(id)

    # preserve original user-provided id if it wasn't usable as a point id
    if id is not None and str(pid) != str(id):
        payload["_original_id"] = id

    col = _col(collection)
    try:
        # QdrantClient is synchronous; keep it off the event loop.
        await run_in_threadpool(
            client.upsert,
            collection_name=col,
            points=[PointStruct(id=pid, vector=vector, payload=payload)],
        )
    except Exception as e:
        raise HTTPException(500, str(e)) from e

    return {"id": pid, "collection": col, "dim": len(vector)}
|
|
|
|
|
|
@app.post("/upsert/vector")
def upsert_vector(req: UpsertVectorRequest):
    """Store a caller-supplied vector as-is, bypassing the CLIP service.

    Returns the point id used, the target collection and the vector length.
    Any failure inside the Qdrant upsert is reported as HTTP 500 with the
    error text.
    """
    point_id = _point_id(req.id)
    collection_name = _col(req.collection)
    body = dict(req.metadata or {})
    id_was_remapped = req.id is not None and str(point_id) != str(req.id)
    if id_was_remapped:
        body["_original_id"] = req.id
    try:
        client.upsert(
            collection_name=collection_name,
            points=[PointStruct(id=point_id, vector=req.vector, payload=body)],
        )
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"id": point_id, "collection": collection_name, "dim": len(req.vector)}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Search endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.post("/search")
async def search_url(req: SearchUrlRequest):
    """Embed an image by URL via CLIP, then search Qdrant for similar vectors.

    The embedding call is awaited. The Qdrant query in ``_do_search`` is
    synchronous, so it runs in the thread pool instead of blocking the event
    loop. Exceptions it raises (HTTPException included) propagate unchanged.
    """
    from fastapi.concurrency import run_in_threadpool

    vector = await _embed_url(req.url)
    return await run_in_threadpool(
        _do_search, vector, req.limit, req.score_threshold, req.collection, req.filter_metadata
    )
|
|
|
|
|
|
@app.post("/search/file")
async def search_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    score_threshold: Optional[float] = Form(None),
    collection: Optional[str] = Form(None),
    filter_json: Optional[str] = Form(None),
):
    """Embed an uploaded image via CLIP, then search Qdrant for similar vectors.

    ``filter_json`` (optional, new) is a JSON object of payload key/value
    pairs. It gives this endpoint the same ``filter_metadata`` filtering that
    ``/search`` and ``/search/vector`` accept in their JSON bodies. Omitting
    it keeps the previous unfiltered behaviour.

    Raises:
        HTTPException(400): ``filter_json`` is not valid JSON or not an object.
        HTTPException(502): the CLIP call failed (raised by ``_embed_bytes``).
    """
    import json

    from fastapi.concurrency import run_in_threadpool

    filter_metadata: Dict[str, Any] = {}
    if filter_json:
        try:
            filter_metadata = json.loads(filter_json)
        except json.JSONDecodeError as e:
            raise HTTPException(400, "filter_json must be valid JSON") from e
        if not isinstance(filter_metadata, dict):
            raise HTTPException(400, "filter_json must be a JSON object")

    data = await file.read()
    vector = await _embed_bytes(data)
    # FastAPI already coerces ``limit`` to int; the Qdrant query is synchronous,
    # so run it in the thread pool to keep the event loop free.
    return await run_in_threadpool(
        _do_search, vector, limit, score_threshold, collection, filter_metadata
    )
|
|
|
|
|
|
@app.post("/search/vector")
def search_vector(req: SearchVectorRequest):
    """Run a similarity search with a caller-supplied query vector (no CLIP call)."""
    return _do_search(
        vector=req.vector,
        limit=req.limit,
        score_threshold=req.score_threshold,
        collection=req.collection,
        filter_metadata=req.filter_metadata,
    )
|
|
|
|
|
|
def _do_search(
    vector: List[float],
    limit: int,
    score_threshold: Optional[float],
    collection: Optional[str],
    filter_metadata: Dict[str, Any],
):
    """Run a nearest-neighbour query and shape the hits for the API response.

    Args:
        vector: Query vector passed to Qdrant as-is.
        limit: Maximum number of hits to return.
        score_threshold: Optional Qdrant score cut-off (None means none).
        collection: Target collection; None/empty selects COLLECTION_NAME.
        filter_metadata: Payload key/value pairs turned into a filter by
            ``_build_filter``.

    Returns:
        ``{"results": [{"id", "score", "metadata"}, ...], "collection": ...,
        "count": ...}``

    Raises:
        HTTPException(400): ``filter_metadata`` contains values that cannot
            be expressed as a Qdrant match condition.
        HTTPException(500): the Qdrant query failed (e.g. unknown collection
            or a vector of the wrong size). The error text is included, the
            same way the upsert endpoints report Qdrant failures.
    """
    col = _col(collection)
    try:
        qfilter = _build_filter(filter_metadata)
    except (ValueError, TypeError) as e:  # qdrant model validation errors subclass ValueError
        raise HTTPException(400, f"Invalid filter_metadata: {e}") from e

    try:
        results = client.query_points(
            collection_name=col,
            query=vector,
            limit=limit,
            score_threshold=score_threshold,
            query_filter=qfilter,
        )
    except Exception as e:
        raise HTTPException(500, str(e)) from e

    hits = [
        {"id": point.id, "score": point.score, "metadata": point.payload}
        for point in results.points
    ]
    return {"results": hits, "collection": col, "count": len(hits)}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Delete points
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.post("/delete")
def delete_points(req: DeleteRequest):
    """Delete points from a collection, mirroring how ``_point_id`` maps ids.

    Each entry of ``req.ids`` is resolved as follows:

    * an integer string in ``[0, 2**64 - 1]`` (e.g. ``"42"``) -> integer point
      id 42. Sending ``"42"`` as-is would fail, because Qdrant string ids
      must be UUIDs.
    * a UUID string -> that UUID in canonical form.
    * anything else -> every point whose ``_original_id`` payload equals it.
      This is where upserts record caller ids that Qdrant could not use
      directly.

    Returns:
        ``{"deleted": req.ids, "collection": <name>}``

    Raises:
        HTTPException(500): Qdrant rejected the delete.
    """
    col = _col(req.collection)

    point_ids: List[Any] = []
    original_ids: List[str] = []
    for raw in req.ids:
        try:
            as_int = int(raw)
        except ValueError:
            as_int = None
        if as_int is not None and 0 <= as_int <= 2**64 - 1:
            point_ids.append(as_int)
            continue
        try:
            point_ids.append(str(uuid.UUID(raw)))
        except ValueError:
            original_ids.append(raw)

    try:
        if point_ids:
            client.delete(collection_name=col, points_selector=point_ids)
        if original_ids:
            # ``should`` = match any of the listed caller ids.
            client.delete(
                collection_name=col,
                points_selector=Filter(
                    should=[
                        FieldCondition(key="_original_id", match=MatchValue(value=v))
                        for v in original_ids
                    ]
                ),
            )
    except Exception as e:
        raise HTTPException(500, str(e)) from e

    return {"deleted": req.ids, "collection": col}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Get point by ID
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.get("/points/{point_id}")
def get_point(point_id: str, collection: Optional[str] = None):
    """Fetch one point (vector and payload) by its Qdrant point id.

    The path segment always arrives as a string, but Qdrant only accepts
    unsigned integers or UUIDs as ids. So an integer string in
    ``[0, 2**64 - 1]`` such as ``"42"`` is converted to the int 42 before the
    lookup; integer ids are what upserts store for numeric caller ids. Other
    caller ids can be looked up via ``/points/by-original-id/{original_id}``.

    Raises:
        HTTPException(404): no such point, or Qdrant rejected the lookup (the
            Qdrant error text becomes the detail).
    """
    col = _col(collection)

    lookup_id: Any = point_id
    try:
        as_int = int(point_id)
    except ValueError:
        pass
    else:
        if 0 <= as_int <= 2**64 - 1:
            lookup_id = as_int

    try:
        points = client.retrieve(collection_name=col, ids=[lookup_id], with_vectors=True)
    except Exception as e:
        raise HTTPException(404, str(e)) from e
    if not points:
        raise HTTPException(404, f"Point '{point_id}' not found")

    p = points[0]
    return {
        "id": p.id,
        "vector": p.vector,
        "metadata": p.payload,
        "collection": col,
    }
|
|
|
|
|
|
@app.get("/points/by-original-id/{original_id}")
def get_point_by_original_id(original_id: str, collection: Optional[str] = None):
    """Look up a point by the caller id stored in its ``_original_id`` payload.

    At most one match is fetched (``limit=1``) and returned with its vector
    and payload. A miss, or any Qdrant error, is reported as HTTP 404.
    """
    col = _col(collection)
    try:
        matches, _next_offset = client.scroll(
            collection_name=col,
            scroll_filter=_id_filter(original_id),
            limit=1,
            with_payload=True,
            with_vectors=True,
        )
        if not matches:
            raise HTTPException(404, f"Point with _original_id '{original_id}' not found")
        first = matches[0]
        return {
            "id": first.id,
            "vector": first.vector,
            "metadata": first.payload,
            "collection": col,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(404, str(e))
|