diff --git a/docker-compose.yml b/docker-compose.yml index b1fa79a..51af856 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -77,6 +77,7 @@ services: - CLIP_URL=http://clip:8000 - COLLECTION_NAME=images - VECTOR_DIM=512 + - SEARCH_HNSW_EF=128 depends_on: qdrant: condition: service_healthy diff --git a/gateway/main.py b/gateway/main.py index 67f7d23..8d61883 100644 --- a/gateway/main.py +++ b/gateway/main.py @@ -416,3 +416,33 @@ async def cards_render_meta(payload: Dict[str, Any]): """Return crop and layout metadata for a card render (no image produced).""" async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_json(client, f"{CARD_RENDERER_URL}/render/meta", payload) + + +# ---- Qdrant administration endpoints (index management + collection config) ---- + +@app.get("/vectors/collections/{name}/indexes") +async def vectors_collection_indexes(name: str): + """List payload indexes for a collection.""" + async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: + return await _get_json(client, f"{QDRANT_SVC_URL}/collections/{name}/indexes") + + +@app.post("/vectors/collections/{name}/indexes") +async def vectors_create_payload_index(name: str, payload: Dict[str, Any]): + """Create a payload index on a field in a collection.""" + async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: + return await _post_json(client, f"{QDRANT_SVC_URL}/collections/{name}/indexes", payload) + + +@app.post("/vectors/collections/{name}/ensure-indexes") +async def vectors_ensure_indexes(name: str, payload: Dict[str, Any]): + """Idempotently ensure payload indexes exist for a list of fields.""" + async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: + return await _post_json(client, f"{QDRANT_SVC_URL}/collections/{name}/ensure-indexes", payload) + + +@app.post("/vectors/collections/{name}/configure") +async def vectors_configure_collection(name: str, payload: Dict[str, Any]): + """Update HNSW and optimizer configuration for a collection.""" + async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: + return await _post_json(client, f"{QDRANT_SVC_URL}/collections/{name}/configure", payload) diff --git a/qdrant/main.py b/qdrant/main.py index b6f345a..5d50687 100644 --- a/qdrant/main.py +++ b/qdrant/main.py @@ -16,6 +16,10 @@ from qdrant_client.models import ( Filter, FieldCondition, MatchValue, + HnswConfigDiff, + OptimizersConfigDiff, + SearchParams, + PayloadSchemaType, ) # --------------------------------------------------------------------------- @@ -27,6 +31,8 @@ QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333")) CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000") COLLECTION_NAME = os.getenv("COLLECTION_NAME", "images") VECTOR_DIM = int(os.getenv("VECTOR_DIM", "512")) +# hnsw_ef at query time: higher = better recall, slightly more latency (Qdrant default ~100) +SEARCH_HNSW_EF = int(os.getenv("SEARCH_HNSW_EF", "128")) app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0") client: QdrantClient = None # type: ignore[assignment] @@ -44,12 +50,21 @@ def startup(): def _ensure_collection(): - """Create the default collection if it does not exist yet.""" + """Create the default collection with production-friendly defaults if it does not exist yet.""" collections = [c.name for c in client.get_collections().collections] if COLLECTION_NAME not in collections: client.create_collection( collection_name=COLLECTION_NAME, vectors_config=VectorParams(size=VECTOR_DIM, distance=Distance.COSINE), + hnsw_config=HnswConfigDiff( + m=16, + ef_construct=200, # higher than default 100 = better index quality + on_disk=False, # keep HNSW graph in RAM for fast traversal + ), + optimizers_config=OptimizersConfigDiff( + indexing_threshold=20000, # start indexing after 20k accumulated vectors + default_segment_number=4, # parallelism-friendly segment count + ), ) @@ -77,6 +92,9 @@ class SearchUrlRequest(BaseModel): score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0) collection: Optional[str] = None filter_metadata: Dict[str, Any] = Field(default_factory=dict) + hnsw_ef: Optional[int] = Field(default=None, ge=1, le=512, description="Override ef at query time. Higher = better recall, slightly higher latency.") + exact: bool = Field(default=False, description="Brute-force exact search. Avoid on large collections.") + indexed_only: bool = Field(default=False, description="Search only fully indexed segments. Useful during bulk ingest.") class SearchVectorRequest(BaseModel): @@ -85,6 +103,9 @@ class SearchVectorRequest(BaseModel): score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0) collection: Optional[str] = None filter_metadata: Dict[str, Any] = Field(default_factory=dict) + hnsw_ef: Optional[int] = Field(default=None, ge=1, le=512) + exact: bool = False + indexed_only: bool = False class DeleteRequest(BaseModel): @@ -221,11 +242,40 @@ def list_collections(): def collection_info(name: str): try: info = client.get_collection(name) + cfg = info.config + hnsw = cfg.hnsw_config + opt = cfg.optimizer_config + quant = cfg.quantization_config return { "name": name, "vectors_count": info.vectors_count, + "indexed_vectors_count": info.indexed_vectors_count, "points_count": info.points_count, + "segments_count": info.segments_count, "status": info.status.value if info.status else None, + "optimizer_status": str(info.optimizer_status) if info.optimizer_status else None, + "hnsw": { + "m": hnsw.m, + "ef_construct": hnsw.ef_construct, + "on_disk": hnsw.on_disk, + "full_scan_threshold": hnsw.full_scan_threshold, + "max_indexing_threads": hnsw.max_indexing_threads, + } if hnsw else None, + "optimizer": { + "indexing_threshold": opt.indexing_threshold, + "default_segment_number": opt.default_segment_number, + "max_segment_size": opt.max_segment_size, + "memmap_threshold": opt.memmap_threshold, + "flush_interval_sec": opt.flush_interval_sec, + } if opt else None, + "quantization": str(quant) if quant else None, + "payload_schema": { + k: { + "type": v.data_type.value if hasattr(v.data_type, "value") else str(v.data_type), + "points": v.points, + } + for k, v in (info.payload_schema or {}).items() + }, } except Exception as e: raise HTTPException(404, str(e)) @@ -325,7 +375,7 @@ def upsert_vector(req: UpsertVectorRequest): async def search_url(req: SearchUrlRequest): """Embed an image by URL via CLIP, then search Qdrant for similar vectors.""" vector = await _embed_url(req.url) - return _do_search(vector, req.limit, req.score_threshold, req.collection, req.filter_metadata) + return _do_search(vector, req.limit, req.score_threshold, req.collection, req.filter_metadata, req.hnsw_ef, req.exact, req.indexed_only) @app.post("/search/file") @@ -344,7 +394,7 @@ async def search_file( @app.post("/search/vector") def search_vector(req: SearchVectorRequest): """Search Qdrant using a pre-computed vector.""" - return _do_search(req.vector, req.limit, req.score_threshold, req.collection, req.filter_metadata) + return _do_search(req.vector, req.limit, req.score_threshold, req.collection, req.filter_metadata, req.hnsw_ef, req.exact, req.indexed_only) def _do_search( @@ -353,9 +403,13 @@ def _do_search( score_threshold: Optional[float], collection: Optional[str], filter_metadata: Dict[str, Any], + hnsw_ef: Optional[int] = None, + exact: bool = False, + indexed_only: bool = False, ): col = _col(collection) qfilter = _build_filter(filter_metadata) + ef = hnsw_ef if hnsw_ef is not None else SEARCH_HNSW_EF results = client.query_points( collection_name=col, @@ -363,6 +417,7 @@ def _do_search( limit=limit, score_threshold=score_threshold, query_filter=qfilter, + search_params=SearchParams(hnsw_ef=ef, exact=exact, indexed_only=indexed_only), ) hits = [] @@ -438,3 +493,156 @@ def get_point_by_original_id(original_id: str, collection: Optional[str] = None) raise except Exception as e: raise HTTPException(404, str(e)) + + +# --------------------------------------------------------------------------- +# Payload index management +# --------------------------------------------------------------------------- + +_SCHEMA_TYPE_MAP: Dict[str, PayloadSchemaType] = { + t.value: t for t in PayloadSchemaType +} + + +def _resolve_schema_type(type_str: str) -> PayloadSchemaType: + schema = _SCHEMA_TYPE_MAP.get(type_str.lower()) + if schema is None: + raise HTTPException(400, f"Unknown index type '{type_str}'. Valid: {', '.join(_SCHEMA_TYPE_MAP)}") + return schema + + +class PayloadIndexRequest(BaseModel): + field: str + type: str = Field(default="keyword", description="keyword | integer | float | bool | geo | datetime | text | uuid") + collection: Optional[str] = None + + +class EnsureIndexesRequest(BaseModel): + """List of field specs, each with 'field' and optional 'type' keys.""" + fields: List[Dict[str, str]] + collection: Optional[str] = None + + +@app.get("/collections/{name}/indexes") +def collection_indexes(name: str): + """List all payload indexes for a collection.""" + try: + info = client.get_collection(name) + schema = info.payload_schema or {} + return { + "collection": name, + "indexes": { + k: { + "type": v.data_type.value if hasattr(v.data_type, "value") else str(v.data_type), + "points": v.points, + } + for k, v in schema.items() + }, + "count": len(schema), + } + except Exception as e: + raise HTTPException(404, str(e)) + + +@app.post("/collections/{name}/indexes") +def create_index(name: str, req: PayloadIndexRequest): + """Create a payload index on a single field.""" + col = req.collection or name + schema = _resolve_schema_type(req.type) + try: + client.create_payload_index( + collection_name=col, + field_name=req.field, + field_schema=schema, + ) + return {"collection": col, "field": req.field, "type": req.type, "status": "created"} + except Exception as e: + raise HTTPException(500, str(e)) + + +@app.post("/collections/{name}/ensure-indexes") +def ensure_indexes(name: str, req: EnsureIndexesRequest): + """Idempotently ensure payload indexes exist for a list of fields. + + Skips fields that are already indexed; only creates the missing ones. + Example body: {"fields": [{"field": "is_public", "type": "bool"}, {"field": "category_id", "type": "integer"}]} + """ + col = req.collection or name + try: + info = client.get_collection(col) + except Exception as e: + raise HTTPException(404, str(e)) + + existing = set(info.payload_schema.keys()) if info.payload_schema else set() + created: List[str] = [] + skipped: List[str] = [] + + for field_spec in req.fields: + field = field_spec.get("field") + type_str = field_spec.get("type", "keyword") + if not field: + raise HTTPException(400, "Each field spec must include a 'field' key") + if field in existing: + skipped.append(field) + continue + schema = _resolve_schema_type(type_str) + try: + client.create_payload_index( + collection_name=col, + field_name=field, + field_schema=schema, + ) + created.append(field) + except Exception as exc: + raise HTTPException(500, f"Failed to index '{field}': {exc}") + + return {"collection": col, "created": created, "skipped": skipped} + + +# --------------------------------------------------------------------------- +# Collection HNSW + optimizer configuration +# --------------------------------------------------------------------------- + +class CollectionConfigRequest(BaseModel): + hnsw_m: Optional[int] = Field(default=None, ge=4, le=64, description="Edges per node in the HNSW graph.") + hnsw_ef_construct: Optional[int] = Field(default=None, ge=10, le=1000, description="ef during index construction. Changes apply to new segments only.") + hnsw_on_disk: Optional[bool] = Field(default=None, description="Store HNSW graph on disk (saves RAM, slightly slower queries).") + indexing_threshold: Optional[int] = Field(default=None, ge=0, description="Min payload changes before a segment is indexed.") + default_segment_number: Optional[int] = Field(default=None, ge=1, le=32, description="Target number of segments for parallelism.") + + +@app.post("/collections/{name}/configure") +def configure_collection(name: str, req: CollectionConfigRequest): + """Apply HNSW and optimizer configuration updates to an existing collection. + + Changes are applied in-place without data loss or re-ingestion. + Note: hnsw_m and hnsw_ef_construct only affect newly created segments. + """ + hnsw_kwargs = {k: v for k, v in { + "m": req.hnsw_m, + "ef_construct": req.hnsw_ef_construct, + "on_disk": req.hnsw_on_disk, + }.items() if v is not None} + + opt_kwargs = {k: v for k, v in { + "indexing_threshold": req.indexing_threshold, + "default_segment_number": req.default_segment_number, + }.items() if v is not None} + + if not hnsw_kwargs and not opt_kwargs: + raise HTTPException(400, "No configuration fields provided") + + try: + client.update_collection( + collection_name=name, + hnsw_config=HnswConfigDiff(**hnsw_kwargs) if hnsw_kwargs else None, + optimizers_config=OptimizersConfigDiff(**opt_kwargs) if opt_kwargs else None, + ) + return { + "collection": name, + "status": "updated", + "hnsw_changes": hnsw_kwargs, + "optimizer_changes": opt_kwargs, + } + except Exception as exc: + raise HTTPException(500, str(exc))