from __future__ import annotations import os import uuid from typing import Any, Dict, List, Optional import httpx import numpy as np from fastapi import FastAPI, HTTPException, UploadFile, File, Form from pydantic import BaseModel, Field from qdrant_client import QdrantClient from qdrant_client.models import ( Distance, PointStruct, VectorParams, Filter, FieldCondition, MatchValue, ) # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant") QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333")) CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000") COLLECTION_NAME = os.getenv("COLLECTION_NAME", "images") VECTOR_DIM = int(os.getenv("VECTOR_DIM", "512")) app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0") client: QdrantClient = None # type: ignore[assignment] # --------------------------------------------------------------------------- # Startup / shutdown # --------------------------------------------------------------------------- @app.on_event("startup") def startup(): global client client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT) _ensure_collection() def _ensure_collection(): """Create the default collection if it does not exist yet.""" collections = [c.name for c in client.get_collections().collections] if COLLECTION_NAME not in collections: client.create_collection( collection_name=COLLECTION_NAME, vectors_config=VectorParams(size=VECTOR_DIM, distance=Distance.COSINE), ) # --------------------------------------------------------------------------- # Request / Response models # --------------------------------------------------------------------------- class UpsertUrlRequest(BaseModel): url: str id: Optional[str] = None metadata: Dict[str, Any] = Field(default_factory=dict) collection: Optional[str] = None class UpsertVectorRequest(BaseModel): vector: List[float] id: Optional[str] = None 
metadata: Dict[str, Any] = Field(default_factory=dict) collection: Optional[str] = None class SearchUrlRequest(BaseModel): url: str limit: int = Field(default=5, ge=1, le=100) score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0) collection: Optional[str] = None filter_metadata: Dict[str, Any] = Field(default_factory=dict) class SearchVectorRequest(BaseModel): vector: List[float] limit: int = Field(default=5, ge=1, le=100) score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0) collection: Optional[str] = None filter_metadata: Dict[str, Any] = Field(default_factory=dict) class DeleteRequest(BaseModel): ids: List[str] collection: Optional[str] = None class CollectionRequest(BaseModel): name: str vector_dim: int = Field(default=512, ge=1) distance: str = Field(default="cosine") # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _col(name: Optional[str]) -> str: return name or COLLECTION_NAME async def _embed_url(url: str) -> List[float]: """Call the CLIP service to get an image embedding.""" async with httpx.AsyncClient(timeout=30) as http: r = await http.post(f"{CLIP_URL}/embed", json={"url": url}) if r.status_code >= 400: raise HTTPException(502, f"CLIP /embed error: {r.status_code} {r.text[:200]}") return r.json()["vector"] async def _embed_bytes(data: bytes) -> List[float]: """Call the CLIP service to embed uploaded file bytes.""" async with httpx.AsyncClient(timeout=30) as http: files = {"file": ("image", data, "application/octet-stream")} r = await http.post(f"{CLIP_URL}/embed/file", files=files) if r.status_code >= 400: raise HTTPException(502, f"CLIP /embed/file error: {r.status_code} {r.text[:200]}") return r.json()["vector"] def _build_filter(metadata: Dict[str, Any]) -> Optional[Filter]: if not metadata: return None conditions = [ FieldCondition(key=k, match=MatchValue(value=v)) for k, v in metadata.items() 
] return Filter(must=conditions) def _point_id(raw: Optional[str]) -> str: return raw or uuid.uuid4().hex # --------------------------------------------------------------------------- # Health # --------------------------------------------------------------------------- @app.get("/health") def health(): try: info = client.get_collections() names = [c.name for c in info.collections] return {"status": "ok", "qdrant": QDRANT_HOST, "collections": names} except Exception as e: return {"status": "error", "detail": str(e)} # --------------------------------------------------------------------------- # Collection management # --------------------------------------------------------------------------- @app.post("/collections") def create_collection(req: CollectionRequest): dist_map = {"cosine": Distance.COSINE, "euclid": Distance.EUCLID, "dot": Distance.DOT} dist = dist_map.get(req.distance.lower()) if dist is None: raise HTTPException(400, f"Unknown distance: {req.distance}. Use cosine, euclid, or dot.") collections = [c.name for c in client.get_collections().collections] if req.name in collections: raise HTTPException(409, f"Collection '{req.name}' already exists") client.create_collection( collection_name=req.name, vectors_config=VectorParams(size=req.vector_dim, distance=dist), ) return {"created": req.name, "vector_dim": req.vector_dim, "distance": req.distance} @app.get("/collections") def list_collections(): info = client.get_collections() return {"collections": [c.name for c in info.collections]} @app.get("/collections/{name}") def collection_info(name: str): try: info = client.get_collection(name) return { "name": name, "vectors_count": info.vectors_count, "points_count": info.points_count, "status": info.status.value if info.status else None, } except Exception as e: raise HTTPException(404, str(e)) @app.delete("/collections/{name}") def delete_collection(name: str): client.delete_collection(name) return {"deleted": name} # 
# ---------------------------------------------------------------------------
# Upsert endpoints
# ---------------------------------------------------------------------------


def _parse_json_object(raw: Optional[str], field: str) -> Dict[str, Any]:
    """Parse an optional multipart form field that must contain a JSON object.

    Returns ``{}`` when the field is absent or empty. Raises a 400 when the
    text is not valid JSON, or when it decodes to something other than an
    object; payloads and metadata filters are key/value maps.
    """
    import json

    if not raw:
        return {}
    try:
        value = json.loads(raw)
    except json.JSONDecodeError:
        raise HTTPException(400, f"{field} must be valid JSON") from None
    if not isinstance(value, dict):
        raise HTTPException(400, f"{field} must be a JSON object")
    return value


@app.post("/upsert")
async def upsert_url(req: UpsertUrlRequest):
    """Embed an image by URL via CLIP, then store the vector in Qdrant.

    The image URL is added to the payload under ``source_url``. That key
    overrides any ``source_url`` supplied in ``req.metadata``.
    """
    vector = await _embed_url(req.url)
    pid = _point_id(req.id)
    payload = {**req.metadata, "source_url": req.url}
    col = _col(req.collection)
    client.upsert(
        collection_name=col,
        points=[PointStruct(id=pid, vector=vector, payload=payload)],
    )
    return {"id": pid, "collection": col, "dim": len(vector)}


@app.post("/upsert/file")
async def upsert_file(
    file: UploadFile = File(...),
    id: Optional[str] = Form(None),  # form field name is public API; shadows builtin
    collection: Optional[str] = Form(None),
    metadata_json: Optional[str] = Form(None),
):
    """Embed an uploaded image via CLIP, then store the vector in Qdrant.

    ``metadata_json`` is an optional JSON object stored as the point payload.
    It is validated before the upload goes to CLIP, so a malformed request
    fails fast with a 400 and no embedding is computed. An empty upload is
    also rejected with a 400.
    """
    payload = _parse_json_object(metadata_json, "metadata_json")
    data = await file.read()
    if not data:
        raise HTTPException(400, "Uploaded file is empty")
    vector = await _embed_bytes(data)
    pid = _point_id(id)
    col = _col(collection)
    client.upsert(
        collection_name=col,
        points=[PointStruct(id=pid, vector=vector, payload=payload)],
    )
    return {"id": pid, "collection": col, "dim": len(vector)}


@app.post("/upsert/vector")
def upsert_vector(req: UpsertVectorRequest):
    """Store a pre-computed vector directly (skip CLIP embedding)."""
    pid = _point_id(req.id)
    col = _col(req.collection)
    client.upsert(
        collection_name=col,
        points=[PointStruct(id=pid, vector=req.vector, payload=req.metadata)],
    )
    return {"id": pid, "collection": col, "dim": len(req.vector)}


# ---------------------------------------------------------------------------
# Search endpoints
# ---------------------------------------------------------------------------


@app.post("/search")
async def search_url(req: SearchUrlRequest):
    """Embed an image by URL via CLIP, then search Qdrant for similar vectors."""
    vector = await _embed_url(req.url)
    return _do_search(vector, req.limit, req.score_threshold, req.collection, req.filter_metadata)


@app.post("/search/file")
async def search_file(
    file: UploadFile = File(...),
    # Same bounds as the JSON search models, so every search endpoint
    # validates consistently (previously unchecked here).
    limit: int = Form(5, ge=1, le=100),
    score_threshold: Optional[float] = Form(None, ge=0.0, le=1.0),
    collection: Optional[str] = Form(None),
    filter_metadata_json: Optional[str] = Form(None),
):
    """Embed an uploaded image via CLIP, then search Qdrant for similar vectors.

    ``filter_metadata_json`` is optional and behaves like ``filter_metadata``
    on the JSON search endpoints. Pass it as a JSON object. It is parsed
    before the upload goes to CLIP.
    """
    filters = _parse_json_object(filter_metadata_json, "filter_metadata_json")
    data = await file.read()
    if not data:
        raise HTTPException(400, "Uploaded file is empty")
    vector = await _embed_bytes(data)
    return _do_search(vector, limit, score_threshold, collection, filters)


@app.post("/search/vector")
def search_vector(req: SearchVectorRequest):
    """Search Qdrant using a pre-computed vector."""
    return _do_search(req.vector, req.limit, req.score_threshold, req.collection, req.filter_metadata)


def _do_search(
    vector: List[float],
    limit: int,
    score_threshold: Optional[float],
    collection: Optional[str],
    filter_metadata: Dict[str, Any],
):
    """Run a nearest-neighbour query and shape the hits for the API response.

    *collection* falls back to the default collection, and *filter_metadata*
    is turned into a Qdrant filter by ``_build_filter``.
    """
    col = _col(collection)
    results = client.query_points(
        collection_name=col,
        query=vector,
        limit=limit,
        score_threshold=score_threshold,
        query_filter=_build_filter(filter_metadata),
    )
    hits = [
        {"id": point.id, "score": point.score, "metadata": point.payload}
        for point in results.points
    ]
    return {"results": hits, "collection": col, "count": len(hits)}


# ---------------------------------------------------------------------------
# Delete points
# ---------------------------------------------------------------------------


@app.post("/delete")
def delete_points(req: DeleteRequest):
    """Delete points by ID and echo the requested IDs.

    The response lists what was requested, not what actually existed.
    """
    col = _col(req.collection)
    client.delete(
        collection_name=col,
        points_selector=req.ids,
    )
    return {"deleted": req.ids, "collection": col}


# ---------------------------------------------------------------------------
# Get point by ID
# ---------------------------------------------------------------------------


@app.get("/points/{point_id}")
def get_point(point_id: str, collection: Optional[str] = None):
    """Fetch one point, including its vector, by ID (404 if it cannot be found)."""
    col = _col(collection)
    try:
        points = client.retrieve(collection_name=col, ids=[point_id], with_vectors=True)
    except Exception as e:
        # Any failure from the client is reported as 404 to keep the previous
        # API behaviour. NOTE(review): this also masks connection errors.
        raise HTTPException(404, str(e)) from e
    if not points:
        raise HTTPException(404, f"Point '{point_id}' not found")
    p = points[0]
    return {
        "id": p.id,
        "vector": p.vector,
        "metadata": p.payload,
        "collection": col,
    }