from __future__ import annotations import os import uuid from typing import Any, Dict, List, Optional import httpx import numpy as np from fastapi import FastAPI, HTTPException, UploadFile, File, Form from pydantic import BaseModel, Field from qdrant_client import QdrantClient from qdrant_client.models import ( Distance, PointStruct, VectorParams, Filter, FieldCondition, MatchValue, ) # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant") QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333")) CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000") COLLECTION_NAME = os.getenv("COLLECTION_NAME", "images") VECTOR_DIM = int(os.getenv("VECTOR_DIM", "512")) app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0") client: QdrantClient = None # type: ignore[assignment] # --------------------------------------------------------------------------- # Startup / shutdown # --------------------------------------------------------------------------- @app.on_event("startup") def startup(): global client client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT) _ensure_collection() def _ensure_collection(): """Create the default collection if it does not exist yet.""" collections = [c.name for c in client.get_collections().collections] if COLLECTION_NAME not in collections: client.create_collection( collection_name=COLLECTION_NAME, vectors_config=VectorParams(size=VECTOR_DIM, distance=Distance.COSINE), ) # --------------------------------------------------------------------------- # Request / Response models # --------------------------------------------------------------------------- class UpsertUrlRequest(BaseModel): url: str id: Optional[str] = None metadata: Dict[str, Any] = Field(default_factory=dict) collection: Optional[str] = None class UpsertVectorRequest(BaseModel): vector: List[float] id: Optional[str] = None metadata: Dict[str, Any] = Field(default_factory=dict) collection: Optional[str] = None class SearchUrlRequest(BaseModel): url: str limit: int = Field(default=5, ge=1, le=100) score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0) collection: Optional[str] = None filter_metadata: Dict[str, Any] = Field(default_factory=dict) class SearchVectorRequest(BaseModel): vector: List[float] limit: int = Field(default=5, ge=1, le=100) score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0) collection: Optional[str] = None filter_metadata: Dict[str, Any] = Field(default_factory=dict) class DeleteRequest(BaseModel): ids: List[str] collection: Optional[str] = None class CollectionRequest(BaseModel): name: str vector_dim: int = Field(default=512, ge=1) distance: str = Field(default="cosine") # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _col(name: Optional[str]) -> str: return name or COLLECTION_NAME async def _embed_url(url: str) -> List[float]: """Call the CLIP service to get an image embedding.""" async with httpx.AsyncClient(timeout=30) as http: try: r = await http.post(f"{CLIP_URL}/embed", json={"url": url}) except httpx.RequestError as e: raise HTTPException(502, f"CLIP request failed: {str(e)}") if r.status_code >= 400: raise HTTPException(502, f"CLIP /embed error: {r.status_code} {r.text[:200]}") try: return r.json()["vector"] except Exception: raise HTTPException(502, f"CLIP /embed returned non-JSON: {r.status_code} {r.text[:200]}") async def _embed_bytes(data: bytes) -> List[float]: """Call the CLIP service to embed uploaded file bytes.""" async with httpx.AsyncClient(timeout=30) as http: files = {"file": ("image", data, "application/octet-stream")} try: r = await http.post(f"{CLIP_URL}/embed/file", files=files) except httpx.RequestError as e: raise HTTPException(502, f"CLIP request failed: {str(e)}") if r.status_code >= 400: raise HTTPException(502, f"CLIP /embed/file error: {r.status_code} {r.text[:200]}") try: return r.json()["vector"] except Exception: raise HTTPException(502, f"CLIP /embed/file returned non-JSON: {r.status_code} {r.text[:200]}") def _build_filter(metadata: Dict[str, Any]) -> Optional[Filter]: if not metadata: return None conditions = [ FieldCondition(key=k, match=MatchValue(value=v)) for k, v in metadata.items() ] return Filter(must=conditions) def _point_id(raw: Optional[str]) -> str: """Return a Qdrant-compatible point id. Qdrant accepts either an unsigned integer or a UUID string (with hyphens). If the provided `raw` value is an int or valid UUID we return it (as int or str). Otherwise we generate a new UUID string and the caller should store the original `raw` value in the point payload under `_original_id`. """ if not raw: return str(uuid.uuid4()) # allow integer ids try: return int(raw) except Exception: pass # allow UUID strings try: u = uuid.UUID(raw) return str(u) except Exception: # fallback: generate a UUID return str(uuid.uuid4()) # --------------------------------------------------------------------------- # Health # --------------------------------------------------------------------------- @app.get("/health") def health(): try: info = client.get_collections() names = [c.name for c in info.collections] return {"status": "ok", "qdrant": QDRANT_HOST, "collections": names} except Exception as e: return {"status": "error", "detail": str(e)} # --------------------------------------------------------------------------- # Collection management # --------------------------------------------------------------------------- @app.post("/collections") def create_collection(req: CollectionRequest): dist_map = {"cosine": Distance.COSINE, "euclid": Distance.EUCLID, "dot": Distance.DOT} dist = dist_map.get(req.distance.lower()) if dist is None: raise HTTPException(400, f"Unknown distance: {req.distance}. Use cosine, euclid, or dot.") collections = [c.name for c in client.get_collections().collections] if req.name in collections: raise HTTPException(409, f"Collection '{req.name}' already exists") client.create_collection( collection_name=req.name, vectors_config=VectorParams(size=req.vector_dim, distance=dist), ) return {"created": req.name, "vector_dim": req.vector_dim, "distance": req.distance} @app.get("/collections") def list_collections(): info = client.get_collections() return {"collections": [c.name for c in info.collections]} @app.get("/collections/{name}") def collection_info(name: str): try: info = client.get_collection(name) return { "name": name, "vectors_count": info.vectors_count, "points_count": info.points_count, "status": info.status.value if info.status else None, } except Exception as e: raise HTTPException(404, str(e)) @app.delete("/collections/{name}") def delete_collection(name: str): client.delete_collection(name) return {"deleted": name} # --------------------------------------------------------------------------- # Upsert endpoints # --------------------------------------------------------------------------- @app.post("/upsert") async def upsert_url(req: UpsertUrlRequest): """Embed an image by URL via CLIP, then store the vector in Qdrant.""" vector = await _embed_url(req.url) pid = _point_id(req.id) payload = {**req.metadata, "source_url": req.url} # preserve original user-provided id if it wasn't usable as a point id if req.id is not None and str(pid) != str(req.id): payload["_original_id"] = req.id col = _col(req.collection) try: client.upsert( collection_name=col, points=[PointStruct(id=pid, vector=vector, payload=payload)], ) except Exception as e: raise HTTPException(500, str(e)) return {"id": pid, "collection": col, "dim": len(vector)} @app.post("/upsert/file") async def upsert_file( file: UploadFile = File(...), id: Optional[str] = Form(None), collection: Optional[str] = Form(None), metadata_json: Optional[str] = Form(None), ): """Embed an uploaded image via CLIP, then store the vector in Qdrant.""" import json data = await file.read() vector = await _embed_bytes(data) pid = _point_id(id) payload: Dict[str, Any] = {} if metadata_json: try: payload = json.loads(metadata_json) except json.JSONDecodeError: raise HTTPException(400, "metadata_json must be valid JSON") # preserve original user-provided id if it wasn't usable as a point id if id is not None and str(pid) != str(id): payload["_original_id"] = id col = _col(collection) try: client.upsert( collection_name=col, points=[PointStruct(id=pid, vector=vector, payload=payload)], ) except Exception as e: raise HTTPException(500, str(e)) return {"id": pid, "collection": col, "dim": len(vector)} @app.post("/upsert/vector") def upsert_vector(req: UpsertVectorRequest): """Store a pre-computed vector directly (skip CLIP embedding).""" pid = _point_id(req.id) col = _col(req.collection) payload = dict(req.metadata or {}) if req.id is not None and str(pid) != str(req.id): payload["_original_id"] = req.id try: client.upsert( collection_name=col, points=[PointStruct(id=pid, vector=req.vector, payload=payload)], ) except Exception as e: raise HTTPException(500, str(e)) return {"id": pid, "collection": col, "dim": len(req.vector)} # --------------------------------------------------------------------------- # Search endpoints # --------------------------------------------------------------------------- @app.post("/search") async def search_url(req: SearchUrlRequest): """Embed an image by URL via CLIP, then search Qdrant for similar vectors.""" vector = await _embed_url(req.url) return _do_search(vector, req.limit, req.score_threshold, req.collection, req.filter_metadata) @app.post("/search/file") async def search_file( file: UploadFile = File(...), limit: int = Form(5), score_threshold: Optional[float] = Form(None), collection: Optional[str] = Form(None), ): """Embed an uploaded image via CLIP, then search Qdrant for similar vectors.""" data = await file.read() vector = await _embed_bytes(data) return _do_search(vector, int(limit), score_threshold, collection, {}) @app.post("/search/vector") def search_vector(req: SearchVectorRequest): """Search Qdrant using a pre-computed vector.""" return _do_search(req.vector, req.limit, req.score_threshold, req.collection, req.filter_metadata) def _do_search( vector: List[float], limit: int, score_threshold: Optional[float], collection: Optional[str], filter_metadata: Dict[str, Any], ): col = _col(collection) qfilter = _build_filter(filter_metadata) results = client.query_points( collection_name=col, query=vector, limit=limit, score_threshold=score_threshold, query_filter=qfilter, ) hits = [] for point in results.points: hits.append({ "id": point.id, "score": point.score, "metadata": point.payload, }) return {"results": hits, "collection": col, "count": len(hits)} # --------------------------------------------------------------------------- # Delete points # --------------------------------------------------------------------------- @app.post("/delete") def delete_points(req: DeleteRequest): col = _col(req.collection) client.delete( collection_name=col, points_selector=req.ids, ) return {"deleted": req.ids, "collection": col} # --------------------------------------------------------------------------- # Get point by ID # --------------------------------------------------------------------------- @app.get("/points/{point_id}") def get_point(point_id: str, collection: Optional[str] = None): col = _col(collection) try: points = client.retrieve(collection_name=col, ids=[point_id], with_vectors=True) if not points: raise HTTPException(404, f"Point '{point_id}' not found") p = points[0] return { "id": p.id, "vector": p.vector, "metadata": p.payload, "collection": col, } except HTTPException: raise except Exception as e: raise HTTPException(404, str(e))