# vision/qdrant/main.py — Skinbase Qdrant vector-search service (FastAPI).
from __future__ import annotations

import os
import uuid
from typing import Any, Dict, List, Optional

import httpx
import numpy as np
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    PointStruct,
    VectorParams,
    Filter,
    FieldCondition,
    MatchValue,
    HnswConfigDiff,
    OptimizersConfigDiff,
    SearchParams,
    PayloadSchemaType,
    ScalarQuantization,
    ScalarQuantizationConfig,
    ScalarType,
)
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Qdrant connection target; "qdrant" is presumably the docker-compose service
# name — confirm against the deployment config.
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
# Base URL of the CLIP embedding service used by the /upsert and /search routes.
CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000")
# Default collection name and embedding dimensionality. 512 must match what the
# CLIP service actually emits — verify when swapping embedding models.
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "images")
VECTOR_DIM = int(os.getenv("VECTOR_DIM", "512"))
# hnsw_ef at query time: higher = better recall, slightly more latency (Qdrant default ~100)
SEARCH_HNSW_EF = int(os.getenv("SEARCH_HNSW_EF", "128"))
app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0")
# Module-level Qdrant client; populated by the startup hook, None before that.
client: QdrantClient = None  # type: ignore[assignment]
# ---------------------------------------------------------------------------
# Startup / shutdown
# ---------------------------------------------------------------------------
@app.on_event("startup")
def startup():
    """Connect to Qdrant and make sure the default collection exists."""
    # NOTE(review): @app.on_event is deprecated in recent FastAPI in favour of
    # lifespan handlers; left unchanged to avoid restructuring app creation.
    global client
    client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
    _ensure_collection()
def _ensure_collection():
    """Create the default collection with production-friendly defaults if it does not exist yet."""
    existing = {c.name for c in client.get_collections().collections}
    if COLLECTION_NAME in existing:
        return
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=VECTOR_DIM, distance=Distance.COSINE),
        # ef_construct above the Qdrant default (100) buys a higher-quality
        # graph; keeping it off disk keeps traversal in RAM and fast.
        hnsw_config=HnswConfigDiff(m=16, ef_construct=200, on_disk=False),
        # Defer HNSW indexing until 20k vectors accumulate; 4 segments give
        # the optimizer room for parallel work.
        optimizers_config=OptimizersConfigDiff(indexing_threshold=20000, default_segment_number=4),
    )
# ---------------------------------------------------------------------------
# Request / Response models
# ---------------------------------------------------------------------------
class UpsertUrlRequest(BaseModel):
    """Body for POST /upsert: embed an image by URL and store the vector."""
    url: str  # image URL forwarded to the CLIP service for embedding
    id: Optional[str] = None  # optional caller id; remapped if not int/UUID (see _point_id)
    metadata: Dict[str, Any] = Field(default_factory=dict)  # stored as the point payload
    collection: Optional[str] = None  # target collection; defaults to COLLECTION_NAME
class UpsertVectorRequest(BaseModel):
    """Body for POST /upsert/vector: store a pre-computed embedding directly."""
    vector: List[float]  # pre-computed embedding; dimensionality must match the collection
    id: Optional[str] = None  # optional caller id; remapped if not int/UUID (see _point_id)
    metadata: Dict[str, Any] = Field(default_factory=dict)  # stored as the point payload
    collection: Optional[str] = None  # target collection; defaults to COLLECTION_NAME
class SearchUrlRequest(BaseModel):
    """Body for POST /search: embed an image URL and find similar vectors."""
    url: str  # image URL forwarded to the CLIP service for embedding
    limit: int = Field(default=5, ge=1, le=100)  # maximum number of hits returned
    score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)  # drop hits scoring below this
    collection: Optional[str] = None  # target collection; defaults to COLLECTION_NAME
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)  # exact-match payload conditions, ANDed
    hnsw_ef: Optional[int] = Field(default=None, ge=1, le=512, description="Override ef at query time. Higher = better recall, slightly higher latency.")
    exact: bool = Field(default=False, description="Brute-force exact search. Avoid on large collections.")
    indexed_only: bool = Field(default=False, description="Search only fully indexed segments. Useful during bulk ingest.")
class SearchVectorRequest(BaseModel):
    """Body for POST /search/vector: search with a pre-computed embedding."""
    vector: List[float]  # query embedding; dimensionality must match the collection
    limit: int = Field(default=5, ge=1, le=100)  # maximum number of hits returned
    score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)  # drop hits scoring below this
    collection: Optional[str] = None  # target collection; defaults to COLLECTION_NAME
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)  # exact-match payload conditions, ANDed
    hnsw_ef: Optional[int] = Field(default=None, ge=1, le=512)  # query-time ef override
    exact: bool = False  # brute-force exact search
    indexed_only: bool = False  # restrict to fully indexed segments
class DeleteRequest(BaseModel):
    """Body for POST /delete: remove points by their Qdrant point ids."""
    ids: List[str]  # Qdrant point ids (not _original_id values)
    collection: Optional[str] = None  # target collection; defaults to COLLECTION_NAME
class CollectionRequest(BaseModel):
    """Body for POST /collections: create a new collection."""
    name: str  # new collection name; must not already exist
    vector_dim: int = Field(default=512, ge=1)  # embedding dimensionality
    distance: str = Field(default="cosine")  # one of: cosine | euclid | dot
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _col(name: Optional[str]) -> str:
return name or COLLECTION_NAME
async def _embed_url(url: str) -> List[float]:
    """Call the CLIP service to get an image embedding for *url* (502 on any failure)."""
    async with httpx.AsyncClient(timeout=30) as http:
        try:
            resp = await http.post(f"{CLIP_URL}/embed", json={"url": url})
        except httpx.RequestError as e:
            # Network-level failure talking to CLIP — surface as a bad gateway.
            raise HTTPException(502, f"CLIP request failed: {str(e)}")
    if resp.status_code >= 400:
        raise HTTPException(502, f"CLIP /embed error: {resp.status_code} {resp.text[:200]}")
    try:
        return resp.json()["vector"]
    except Exception:
        # Covers both invalid JSON and a missing "vector" key.
        raise HTTPException(502, f"CLIP /embed returned non-JSON: {resp.status_code} {resp.text[:200]}")
async def _embed_bytes(data: bytes) -> List[float]:
    """Call the CLIP service to embed uploaded file bytes (502 on any failure)."""
    upload = {"file": ("image", data, "application/octet-stream")}
    async with httpx.AsyncClient(timeout=30) as http:
        try:
            resp = await http.post(f"{CLIP_URL}/embed/file", files=upload)
        except httpx.RequestError as e:
            # Network-level failure talking to CLIP — surface as a bad gateway.
            raise HTTPException(502, f"CLIP request failed: {str(e)}")
    if resp.status_code >= 400:
        raise HTTPException(502, f"CLIP /embed/file error: {resp.status_code} {resp.text[:200]}")
    try:
        return resp.json()["vector"]
    except Exception:
        # Covers both invalid JSON and a missing "vector" key.
        raise HTTPException(502, f"CLIP /embed/file returned non-JSON: {resp.status_code} {resp.text[:200]}")
def _build_filter(metadata: Dict[str, Any]) -> Optional[Filter]:
    """Turn a flat metadata dict into a Qdrant must-match filter; None when empty."""
    if not metadata:
        return None
    must = []
    for key, value in metadata.items():
        must.append(FieldCondition(key=key, match=MatchValue(value=value)))
    return Filter(must=must)
def _id_filter(original_id: str) -> Filter:
    """Exact-match filter on the reserved `_original_id` payload field."""
    condition = FieldCondition(key="_original_id", match=MatchValue(value=original_id))
    return Filter(must=[condition])
def _point_id(raw: Optional[str]) -> str:
"""Return a Qdrant-compatible point id.
Qdrant accepts either an unsigned integer or a UUID string (with hyphens).
If the provided `raw` value is an int or valid UUID we return it (as int or str).
Otherwise we generate a new UUID string and the caller should store the
original `raw` value in the point payload under `_original_id`.
"""
if not raw:
return str(uuid.uuid4())
# allow integer ids
try:
return int(raw)
except Exception:
pass
# allow UUID strings
try:
u = uuid.UUID(raw)
return str(u)
except Exception:
# fallback: generate a UUID
return str(uuid.uuid4())
# ---------------------------------------------------------------------------
# Health
# ---------------------------------------------------------------------------
@app.get("/health")
def health():
    """Liveness probe: list reachable collections, or report the connection error."""
    try:
        names = [c.name for c in client.get_collections().collections]
    except Exception as e:
        return {"status": "error", "detail": str(e)}
    return {"status": "ok", "qdrant": QDRANT_HOST, "collections": names}
@app.get("/inspect")
def inspect():
    """Return a full diagnostic summary for every collection.

    Covers: vector counts, segment counts, HNSW config, optimizer config,
    quantization, payload indexes and their coverage. Designed for production
    health checks and the Qdrant optimization workflow.
    """
    try:
        all_collections = client.get_collections().collections
    except Exception as exc:
        # Qdrant unreachable: report the error instead of raising a 500.
        return {"status": "error", "detail": str(exc)}
    result = {}
    for col_desc in all_collections:
        name = col_desc.name
        try:
            info = client.get_collection(name)
            cfg = info.config
            hnsw = cfg.hnsw_config
            opt = cfg.optimizer_config
            quant = cfg.quantization_config
            params = cfg.params
            # Estimate raw RAM footprint: vectors * dim * 4 bytes * 1.5 safety factor
            vec_count = info.vectors_count or 0
            # params.vectors may not expose .size (e.g. named-vector configs —
            # TODO confirm); fall back to the service default dimension then.
            vec_dim = (
                params.vectors.size
                if hasattr(params.vectors, "size")
                else VECTOR_DIM
            )
            ram_estimate_mb = round(vec_count * vec_dim * 4 * 1.5 / 1_048_576, 1)
            result[name] = {
                "status": info.status.value if info.status else None,
                "optimizer_status": str(info.optimizer_status) if info.optimizer_status else None,
                "vectors_count": vec_count,
                "indexed_vectors_count": info.indexed_vectors_count,
                "points_count": info.points_count,
                "segments_count": info.segments_count,
                "ram_estimate_mb": ram_estimate_mb,
                "hnsw": {
                    "m": hnsw.m,
                    "ef_construct": hnsw.ef_construct,
                    "on_disk": hnsw.on_disk,
                    "full_scan_threshold": hnsw.full_scan_threshold,
                    "max_indexing_threads": hnsw.max_indexing_threads,
                } if hnsw else None,
                "optimizer": {
                    "indexing_threshold": opt.indexing_threshold,
                    "default_segment_number": opt.default_segment_number,
                    "max_segment_size": opt.max_segment_size,
                    "memmap_threshold": opt.memmap_threshold,
                    "flush_interval_sec": opt.flush_interval_sec,
                } if opt else None,
                "quantization": str(quant) if quant else None,
                # Per-field payload index info plus the fraction of points covered.
                "payload_indexes": {
                    k: {
                        "type": v.data_type.value if hasattr(v.data_type, "value") else str(v.data_type),
                        "points": v.points,
                        "coverage_pct": round(v.points / max(vec_count, 1) * 100, 1),
                    }
                    for k, v in (info.payload_schema or {}).items()
                },
                "payload_index_count": len(info.payload_schema or {}),
                "search_hnsw_ef": SEARCH_HNSW_EF,
            }
        except Exception as exc:
            # A failure on one collection must not abort the whole report.
            result[name] = {"error": str(exc)}
    return {"collections": result, "total": len(result)}
# ---------------------------------------------------------------------------
# Collection management
# ---------------------------------------------------------------------------
@app.post("/collections")
def create_collection(req: CollectionRequest):
    """Create a new collection with the same tuned defaults as _ensure_collection."""
    distances = {"cosine": Distance.COSINE, "euclid": Distance.EUCLID, "dot": Distance.DOT}
    try:
        dist = distances[req.distance.lower()]
    except KeyError:
        raise HTTPException(400, f"Unknown distance: {req.distance}. Use cosine, euclid, or dot.")
    existing = [c.name for c in client.get_collections().collections]
    if req.name in existing:
        raise HTTPException(409, f"Collection '{req.name}' already exists")
    # Apply the same production defaults as _ensure_collection so all
    # collections start with tuned HNSW and optimizer settings.
    client.create_collection(
        collection_name=req.name,
        vectors_config=VectorParams(size=req.vector_dim, distance=dist),
        hnsw_config=HnswConfigDiff(m=16, ef_construct=200, on_disk=False),
        optimizers_config=OptimizersConfigDiff(indexing_threshold=20000, default_segment_number=4),
    )
    return {"created": req.name, "vector_dim": req.vector_dim, "distance": req.distance}
@app.get("/collections")
def list_collections():
    """Return the names of all existing collections."""
    names = [c.name for c in client.get_collections().collections]
    return {"collections": names}
@app.get("/collections/{name}")
def collection_info(name: str):
    """Return detailed stats and configuration for one collection (404 if missing)."""
    try:
        info = client.get_collection(name)
        cfg = info.config
        hnsw = cfg.hnsw_config
        opt = cfg.optimizer_config
        quant = cfg.quantization_config
        return {
            "name": name,
            "vectors_count": info.vectors_count,
            "indexed_vectors_count": info.indexed_vectors_count,
            "points_count": info.points_count,
            "segments_count": info.segments_count,
            "status": info.status.value if info.status else None,
            "optimizer_status": str(info.optimizer_status) if info.optimizer_status else None,
            "hnsw": {
                "m": hnsw.m,
                "ef_construct": hnsw.ef_construct,
                "on_disk": hnsw.on_disk,
                "full_scan_threshold": hnsw.full_scan_threshold,
                "max_indexing_threads": hnsw.max_indexing_threads,
            } if hnsw else None,
            "optimizer": {
                "indexing_threshold": opt.indexing_threshold,
                "default_segment_number": opt.default_segment_number,
                "max_segment_size": opt.max_segment_size,
                "memmap_threshold": opt.memmap_threshold,
                "flush_interval_sec": opt.flush_interval_sec,
            } if opt else None,
            "quantization": str(quant) if quant else None,
            # Payload index name -> type and indexed point count.
            "payload_schema": {
                k: {
                    "type": v.data_type.value if hasattr(v.data_type, "value") else str(v.data_type),
                    "points": v.points,
                }
                for k, v in (info.payload_schema or {}).items()
            },
        }
    except Exception as e:
        # Any failure (typically an unknown collection) maps to 404.
        raise HTTPException(404, str(e))
@app.delete("/collections/{name}")
def delete_collection(name: str):
    """Drop a collection.

    Maps Qdrant errors (typically an unknown collection name) to 404 instead
    of letting them escape as an unhandled 500, consistent with the GET route.
    """
    try:
        client.delete_collection(name)
    except Exception as e:
        raise HTTPException(404, str(e))
    return {"deleted": name}
# ---------------------------------------------------------------------------
# Upsert endpoints
# ---------------------------------------------------------------------------
@app.post("/upsert")
async def upsert_url(req: UpsertUrlRequest):
    """Embed an image by URL via CLIP, then store the vector in Qdrant."""
    vector = await _embed_url(req.url)
    pid = _point_id(req.id)
    payload = dict(req.metadata)
    payload["source_url"] = req.url
    if req.id is not None and str(pid) != str(req.id):
        # The caller's id wasn't usable as a point id — keep it recoverable.
        payload["_original_id"] = req.id
    col = _col(req.collection)
    point = PointStruct(id=pid, vector=vector, payload=payload)
    try:
        client.upsert(collection_name=col, points=[point])
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"id": pid, "collection": col, "dim": len(vector)}
@app.post("/upsert/file")
async def upsert_file(
    file: UploadFile = File(...),
    id: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
    metadata_json: Optional[str] = Form(None),
):
    """Embed an uploaded image via CLIP, then store the vector in Qdrant."""
    import json
    data = await file.read()
    vector = await _embed_bytes(data)
    pid = _point_id(id)
    payload: Dict[str, Any] = {}
    if metadata_json:
        try:
            payload = json.loads(metadata_json)
        except json.JSONDecodeError:
            raise HTTPException(400, "metadata_json must be valid JSON")
    if id is not None and str(pid) != str(id):
        # The caller's id wasn't usable as a point id — keep it recoverable.
        payload["_original_id"] = id
    col = _col(collection)
    point = PointStruct(id=pid, vector=vector, payload=payload)
    try:
        client.upsert(collection_name=col, points=[point])
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"id": pid, "collection": col, "dim": len(vector)}
@app.post("/upsert/vector")
def upsert_vector(req: UpsertVectorRequest):
    """Store a pre-computed vector directly (skip CLIP embedding)."""
    pid = _point_id(req.id)
    col = _col(req.collection)
    payload = dict(req.metadata or {})
    if req.id is not None and str(pid) != str(req.id):
        # The caller's id wasn't usable as a point id — keep it recoverable.
        payload["_original_id"] = req.id
    point = PointStruct(id=pid, vector=req.vector, payload=payload)
    try:
        client.upsert(collection_name=col, points=[point])
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"id": pid, "collection": col, "dim": len(req.vector)}
# ---------------------------------------------------------------------------
# Search endpoints
# ---------------------------------------------------------------------------
@app.post("/search")
async def search_url(req: SearchUrlRequest):
    """Embed an image by URL via CLIP, then search Qdrant for similar vectors."""
    vector = await _embed_url(req.url)
    return _do_search(
        vector,
        req.limit,
        req.score_threshold,
        req.collection,
        req.filter_metadata,
        req.hnsw_ef,
        req.exact,
        req.indexed_only,
    )
@app.post("/search/file")
async def search_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    score_threshold: Optional[float] = Form(None),
    collection: Optional[str] = Form(None),
    hnsw_ef: Optional[int] = Form(None),
    exact: bool = Form(False),
    indexed_only: bool = Form(False),
    filter_metadata_json: Optional[str] = Form(None),
):
    """Embed an uploaded image via CLIP, then search Qdrant for similar vectors."""
    import json
    # Validate the filter JSON before spending time on the embedding call.
    filter_metadata: Dict[str, Any] = {}
    if filter_metadata_json:
        try:
            filter_metadata = json.loads(filter_metadata_json)
        except json.JSONDecodeError:
            raise HTTPException(400, "filter_metadata_json must be valid JSON")
    vector = await _embed_bytes(await file.read())
    return _do_search(
        vector,
        int(limit),
        score_threshold,
        collection,
        filter_metadata,
        hnsw_ef,
        exact,
        indexed_only,
    )
@app.post("/search/vector")
def search_vector(req: SearchVectorRequest):
    """Search Qdrant using a pre-computed vector."""
    return _do_search(
        req.vector,
        req.limit,
        req.score_threshold,
        req.collection,
        req.filter_metadata,
        req.hnsw_ef,
        req.exact,
        req.indexed_only,
    )
def _do_search(
    vector: List[float],
    limit: int,
    score_threshold: Optional[float],
    collection: Optional[str],
    filter_metadata: Dict[str, Any],
    hnsw_ef: Optional[int] = None,
    exact: bool = False,
    indexed_only: bool = False,
):
    """Run a vector query against Qdrant and shape the hits for the API response."""
    col = _col(collection)
    # Fall back to the service-wide default ef when the caller didn't override it.
    effective_ef = SEARCH_HNSW_EF if hnsw_ef is None else hnsw_ef
    response = client.query_points(
        collection_name=col,
        query=vector,
        limit=limit,
        score_threshold=score_threshold,
        query_filter=_build_filter(filter_metadata),
        search_params=SearchParams(hnsw_ef=effective_ef, exact=exact, indexed_only=indexed_only),
    )
    hits = [
        {"id": p.id, "score": p.score, "metadata": p.payload}
        for p in response.points
    ]
    return {"results": hits, "collection": col, "count": len(hits)}
# ---------------------------------------------------------------------------
# Delete points
# ---------------------------------------------------------------------------
@app.post("/delete")
def delete_points(req: DeleteRequest):
    """Delete points by their Qdrant point ids."""
    col = _col(req.collection)
    client.delete(collection_name=col, points_selector=req.ids)
    return {"deleted": req.ids, "collection": col}
# ---------------------------------------------------------------------------
# Get point by ID
# ---------------------------------------------------------------------------
@app.get("/points/{point_id}")
def get_point(point_id: str, collection: Optional[str] = None):
    """Fetch a single point (vector + payload) by its Qdrant id; 404 if absent."""
    col = _col(collection)
    try:
        found = client.retrieve(collection_name=col, ids=[point_id], with_vectors=True)
    except Exception as e:
        raise HTTPException(404, str(e))
    if not found:
        raise HTTPException(404, f"Point '{point_id}' not found")
    record = found[0]
    return {
        "id": record.id,
        "vector": record.vector,
        "metadata": record.payload,
        "collection": col,
    }
@app.get("/points/by-original-id/{original_id}")
def get_point_by_original_id(original_id: str, collection: Optional[str] = None):
    """Look up a point by the caller-supplied id preserved in `_original_id`; 404 if absent."""
    col = _col(collection)
    try:
        records, _next_offset = client.scroll(
            collection_name=col,
            scroll_filter=_id_filter(original_id),
            limit=1,
            with_vectors=True,
            with_payload=True,
        )
    except Exception as e:
        raise HTTPException(404, str(e))
    if not records:
        raise HTTPException(404, f"Point with _original_id '{original_id}' not found")
    record = records[0]
    return {
        "id": record.id,
        "vector": record.vector,
        "metadata": record.payload,
        "collection": col,
    }
# ---------------------------------------------------------------------------
# Payload index management
# ---------------------------------------------------------------------------
# Lowercase type name -> PayloadSchemaType, built from the enum's own values
# so it stays in sync with whatever the installed qdrant-client supports.
_SCHEMA_TYPE_MAP: Dict[str, PayloadSchemaType] = {
    t.value: t for t in PayloadSchemaType
}
def _resolve_schema_type(type_str: str) -> PayloadSchemaType:
    """Map a case-insensitive type name to a PayloadSchemaType, or raise 400."""
    try:
        return _SCHEMA_TYPE_MAP[type_str.lower()]
    except KeyError:
        raise HTTPException(400, f"Unknown index type '{type_str}'. Valid: {', '.join(_SCHEMA_TYPE_MAP)}")
class PayloadIndexRequest(BaseModel):
    """Body for POST /collections/{name}/indexes: index one payload field."""
    field: str  # payload field name to index
    type: str = Field(default="keyword", description="keyword | integer | float | bool | geo | datetime | text | uuid")
    collection: Optional[str] = None  # overrides the path's collection name if set
class EnsureIndexesRequest(BaseModel):
    """List of field specs, each with 'field' and optional 'type' keys."""
    fields: List[Dict[str, str]]  # e.g. [{"field": "is_public", "type": "bool"}]
    collection: Optional[str] = None  # overrides the path's collection name if set
@app.get("/collections/{name}/indexes")
def collection_indexes(name: str):
    """List all payload indexes for a collection."""
    try:
        info = client.get_collection(name)
        schema = info.payload_schema or {}
        indexes = {}
        for field, desc in schema.items():
            type_name = desc.data_type.value if hasattr(desc.data_type, "value") else str(desc.data_type)
            indexes[field] = {"type": type_name, "points": desc.points}
        return {"collection": name, "indexes": indexes, "count": len(schema)}
    except Exception as e:
        raise HTTPException(404, str(e))
@app.post("/collections/{name}/indexes")
def create_index(name: str, req: PayloadIndexRequest):
    """Create a payload index on a single field."""
    col = req.collection or name
    schema = _resolve_schema_type(req.type)
    try:
        client.create_payload_index(collection_name=col, field_name=req.field, field_schema=schema)
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"collection": col, "field": req.field, "type": req.type, "status": "created"}
@app.post("/collections/{name}/ensure-indexes")
def ensure_indexes(name: str, req: EnsureIndexesRequest):
    """Idempotently ensure payload indexes exist for a list of fields.

    Skips fields that are already indexed; only creates the missing ones.
    All field specs are validated up front, so a malformed spec fails the
    whole request with 400 *before* any index is created — previously a bad
    type string in the middle of the list left earlier indexes created and
    later ones not (partial application).
    Example body: {"fields": [{"field": "is_public", "type": "bool"}, {"field": "category_id", "type": "integer"}]}
    """
    col = req.collection or name
    try:
        info = client.get_collection(col)
    except Exception as e:
        raise HTTPException(404, str(e))
    existing = set(info.payload_schema.keys()) if info.payload_schema else set()
    # Phase 1: validate every spec (field present, type resolvable) before
    # touching Qdrant.
    resolved: List[tuple] = []
    for field_spec in req.fields:
        field = field_spec.get("field")
        if not field:
            raise HTTPException(400, "Each field spec must include a 'field' key")
        resolved.append((field, _resolve_schema_type(field_spec.get("type", "keyword"))))
    # Phase 2: create only the indexes that don't exist yet.
    created: List[str] = []
    skipped: List[str] = []
    for field, schema in resolved:
        if field in existing:
            skipped.append(field)
            continue
        try:
            client.create_payload_index(
                collection_name=col,
                field_name=field,
                field_schema=schema,
            )
        except Exception as exc:
            raise HTTPException(500, f"Failed to index '{field}': {exc}")
        created.append(field)
    return {"collection": col, "created": created, "skipped": skipped}
# ---------------------------------------------------------------------------
# Collection HNSW + optimizer configuration
# ---------------------------------------------------------------------------
class CollectionConfigRequest(BaseModel):
    """Body for POST /collections/{name}/configure: in-place tuning updates.

    All fields are optional; only the provided ones are applied.
    """
    hnsw_m: Optional[int] = Field(default=None, ge=4, le=64, description="Edges per node in the HNSW graph.")
    hnsw_ef_construct: Optional[int] = Field(default=None, ge=10, le=1000, description="ef during index construction. Changes apply to new segments only.")
    hnsw_on_disk: Optional[bool] = Field(default=None, description="Store HNSW graph on disk (saves RAM, slightly slower queries).")
    indexing_threshold: Optional[int] = Field(default=None, ge=0, description="Min payload changes before a segment is indexed.")
    default_segment_number: Optional[int] = Field(default=None, ge=1, le=32, description="Target number of segments for parallelism.")
    # Scalar quantization — reduces RAM ~4x, often speeds up search on large collections.
    # Set quantization_type='int8' to enable. Use always_ram=True to keep quantized
    # vectors in RAM (recommended on VPS with limited memory but fast disk).
    quantization_type: Optional[str] = Field(default=None, description="Enable scalar quantization: 'int8'. Set to null to keep current setting.")
    quantization_quantile: float = Field(default=0.99, ge=0.5, le=1.0, description="Fraction of vectors used to calibrate quantization range (0.99 recommended).")
    quantization_always_ram: bool = Field(default=True, description="Keep quantized vectors in RAM even when raw vectors are on disk.")
@app.post("/collections/{name}/configure")
def configure_collection(name: str, req: CollectionConfigRequest):
    """Apply HNSW and optimizer configuration updates to an existing collection.

    Changes are applied in-place without data loss or re-ingestion.
    Note: hnsw_m and hnsw_ef_construct only affect newly created segments.
    """
    hnsw_kwargs = {k: v for k, v in {
        "m": req.hnsw_m,
        "ef_construct": req.hnsw_ef_construct,
        "on_disk": req.hnsw_on_disk,
    }.items() if v is not None}
    opt_kwargs = {k: v for k, v in {
        "indexing_threshold": req.indexing_threshold,
        "default_segment_number": req.default_segment_number,
    }.items() if v is not None}
    # Build optional scalar quantization config.
    quant_config = None
    if req.quantization_type is not None:
        if req.quantization_type.lower() != "int8":
            raise HTTPException(400, f"Unsupported quantization_type '{req.quantization_type}'. Only 'int8' is supported.")
        # BUGFIX: update_collection expects the ScalarQuantization wrapper
        # (REST shape {"scalar": {...}}), not a bare ScalarQuantizationConfig,
        # which fails qdrant-client validation.
        quant_config = ScalarQuantization(
            scalar=ScalarQuantizationConfig(
                type=ScalarType.INT8,
                quantile=req.quantization_quantile,
                always_ram=req.quantization_always_ram,
            )
        )
    if not hnsw_kwargs and not opt_kwargs and quant_config is None:
        raise HTTPException(400, "No configuration fields provided")
    try:
        client.update_collection(
            collection_name=name,
            hnsw_config=HnswConfigDiff(**hnsw_kwargs) if hnsw_kwargs else None,
            optimizers_config=OptimizersConfigDiff(**opt_kwargs) if opt_kwargs else None,
            quantization_config=quant_config,
        )
        return {
            "collection": name,
            "status": "updated",
            "hnsw_changes": hnsw_kwargs,
            "optimizer_changes": opt_kwargs,
            "quantization": {"type": req.quantization_type, "quantile": req.quantization_quantile, "always_ram": req.quantization_always_ram} if quant_config else None,
        }
    except Exception as exc:
        raise HTTPException(500, str(exc))