first commit
This commit is contained in:
17
qdrant/Dockerfile
Normal file
17
qdrant/Dockerfile
Normal file
@@ -0,0 +1,17 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY qdrant/requirements.txt /app/requirements.txt
|
||||
RUN pip install --no-cache-dir -r /app/requirements.txt
|
||||
|
||||
COPY qdrant /app
|
||||
COPY common /app/common
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
358
qdrant/main.py
Normal file
358
qdrant/main.py
Normal file
@@ -0,0 +1,358 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel, Field
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
Distance,
|
||||
PointStruct,
|
||||
VectorParams,
|
||||
Filter,
|
||||
FieldCondition,
|
||||
MatchValue,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
|
||||
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
|
||||
CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000")
|
||||
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "images")
|
||||
VECTOR_DIM = int(os.getenv("VECTOR_DIM", "512"))
|
||||
|
||||
app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0")
|
||||
client: QdrantClient = None # type: ignore[assignment]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Startup / shutdown
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup():
|
||||
global client
|
||||
client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
|
||||
_ensure_collection()
|
||||
|
||||
|
||||
def _ensure_collection():
|
||||
"""Create the default collection if it does not exist yet."""
|
||||
collections = [c.name for c in client.get_collections().collections]
|
||||
if COLLECTION_NAME not in collections:
|
||||
client.create_collection(
|
||||
collection_name=COLLECTION_NAME,
|
||||
vectors_config=VectorParams(size=VECTOR_DIM, distance=Distance.COSINE),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / Response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class UpsertUrlRequest(BaseModel):
|
||||
url: str
|
||||
id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
collection: Optional[str] = None
|
||||
|
||||
|
||||
class UpsertVectorRequest(BaseModel):
|
||||
vector: List[float]
|
||||
id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
collection: Optional[str] = None
|
||||
|
||||
|
||||
class SearchUrlRequest(BaseModel):
|
||||
url: str
|
||||
limit: int = Field(default=5, ge=1, le=100)
|
||||
score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
collection: Optional[str] = None
|
||||
filter_metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SearchVectorRequest(BaseModel):
|
||||
vector: List[float]
|
||||
limit: int = Field(default=5, ge=1, le=100)
|
||||
score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
collection: Optional[str] = None
|
||||
filter_metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class DeleteRequest(BaseModel):
|
||||
ids: List[str]
|
||||
collection: Optional[str] = None
|
||||
|
||||
|
||||
class CollectionRequest(BaseModel):
|
||||
name: str
|
||||
vector_dim: int = Field(default=512, ge=1)
|
||||
distance: str = Field(default="cosine")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _col(name: Optional[str]) -> str:
|
||||
return name or COLLECTION_NAME
|
||||
|
||||
|
||||
async def _embed_url(url: str) -> List[float]:
|
||||
"""Call the CLIP service to get an image embedding."""
|
||||
async with httpx.AsyncClient(timeout=30) as http:
|
||||
r = await http.post(f"{CLIP_URL}/embed", json={"url": url})
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(502, f"CLIP /embed error: {r.status_code} {r.text[:200]}")
|
||||
return r.json()["vector"]
|
||||
|
||||
|
||||
async def _embed_bytes(data: bytes) -> List[float]:
|
||||
"""Call the CLIP service to embed uploaded file bytes."""
|
||||
async with httpx.AsyncClient(timeout=30) as http:
|
||||
files = {"file": ("image", data, "application/octet-stream")}
|
||||
r = await http.post(f"{CLIP_URL}/embed/file", files=files)
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(502, f"CLIP /embed/file error: {r.status_code} {r.text[:200]}")
|
||||
return r.json()["vector"]
|
||||
|
||||
|
||||
def _build_filter(metadata: Dict[str, Any]) -> Optional[Filter]:
|
||||
if not metadata:
|
||||
return None
|
||||
conditions = [
|
||||
FieldCondition(key=k, match=MatchValue(value=v))
|
||||
for k, v in metadata.items()
|
||||
]
|
||||
return Filter(must=conditions)
|
||||
|
||||
|
||||
def _point_id(raw: Optional[str]) -> str:
|
||||
return raw or uuid.uuid4().hex
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
try:
|
||||
info = client.get_collections()
|
||||
names = [c.name for c in info.collections]
|
||||
return {"status": "ok", "qdrant": QDRANT_HOST, "collections": names}
|
||||
except Exception as e:
|
||||
return {"status": "error", "detail": str(e)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Collection management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.post("/collections")
|
||||
def create_collection(req: CollectionRequest):
|
||||
dist_map = {"cosine": Distance.COSINE, "euclid": Distance.EUCLID, "dot": Distance.DOT}
|
||||
dist = dist_map.get(req.distance.lower())
|
||||
if dist is None:
|
||||
raise HTTPException(400, f"Unknown distance: {req.distance}. Use cosine, euclid, or dot.")
|
||||
|
||||
collections = [c.name for c in client.get_collections().collections]
|
||||
if req.name in collections:
|
||||
raise HTTPException(409, f"Collection '{req.name}' already exists")
|
||||
|
||||
client.create_collection(
|
||||
collection_name=req.name,
|
||||
vectors_config=VectorParams(size=req.vector_dim, distance=dist),
|
||||
)
|
||||
return {"created": req.name, "vector_dim": req.vector_dim, "distance": req.distance}
|
||||
|
||||
|
||||
@app.get("/collections")
|
||||
def list_collections():
|
||||
info = client.get_collections()
|
||||
return {"collections": [c.name for c in info.collections]}
|
||||
|
||||
|
||||
@app.get("/collections/{name}")
|
||||
def collection_info(name: str):
|
||||
try:
|
||||
info = client.get_collection(name)
|
||||
return {
|
||||
"name": name,
|
||||
"vectors_count": info.vectors_count,
|
||||
"points_count": info.points_count,
|
||||
"status": info.status.value if info.status else None,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(404, str(e))
|
||||
|
||||
|
||||
@app.delete("/collections/{name}")
|
||||
def delete_collection(name: str):
|
||||
client.delete_collection(name)
|
||||
return {"deleted": name}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upsert endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.post("/upsert")
|
||||
async def upsert_url(req: UpsertUrlRequest):
|
||||
"""Embed an image by URL via CLIP, then store the vector in Qdrant."""
|
||||
vector = await _embed_url(req.url)
|
||||
pid = _point_id(req.id)
|
||||
payload = {**req.metadata, "source_url": req.url}
|
||||
col = _col(req.collection)
|
||||
|
||||
client.upsert(
|
||||
collection_name=col,
|
||||
points=[PointStruct(id=pid, vector=vector, payload=payload)],
|
||||
)
|
||||
return {"id": pid, "collection": col, "dim": len(vector)}
|
||||
|
||||
|
||||
@app.post("/upsert/file")
|
||||
async def upsert_file(
|
||||
file: UploadFile = File(...),
|
||||
id: Optional[str] = Form(None),
|
||||
collection: Optional[str] = Form(None),
|
||||
metadata_json: Optional[str] = Form(None),
|
||||
):
|
||||
"""Embed an uploaded image via CLIP, then store the vector in Qdrant."""
|
||||
import json
|
||||
|
||||
data = await file.read()
|
||||
vector = await _embed_bytes(data)
|
||||
pid = _point_id(id)
|
||||
|
||||
payload: Dict[str, Any] = {}
|
||||
if metadata_json:
|
||||
try:
|
||||
payload = json.loads(metadata_json)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(400, "metadata_json must be valid JSON")
|
||||
|
||||
col = _col(collection)
|
||||
client.upsert(
|
||||
collection_name=col,
|
||||
points=[PointStruct(id=pid, vector=vector, payload=payload)],
|
||||
)
|
||||
return {"id": pid, "collection": col, "dim": len(vector)}
|
||||
|
||||
|
||||
@app.post("/upsert/vector")
|
||||
def upsert_vector(req: UpsertVectorRequest):
|
||||
"""Store a pre-computed vector directly (skip CLIP embedding)."""
|
||||
pid = _point_id(req.id)
|
||||
col = _col(req.collection)
|
||||
client.upsert(
|
||||
collection_name=col,
|
||||
points=[PointStruct(id=pid, vector=req.vector, payload=req.metadata)],
|
||||
)
|
||||
return {"id": pid, "collection": col, "dim": len(req.vector)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Search endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.post("/search")
|
||||
async def search_url(req: SearchUrlRequest):
|
||||
"""Embed an image by URL via CLIP, then search Qdrant for similar vectors."""
|
||||
vector = await _embed_url(req.url)
|
||||
return _do_search(vector, req.limit, req.score_threshold, req.collection, req.filter_metadata)
|
||||
|
||||
|
||||
@app.post("/search/file")
|
||||
async def search_file(
|
||||
file: UploadFile = File(...),
|
||||
limit: int = Form(5),
|
||||
score_threshold: Optional[float] = Form(None),
|
||||
collection: Optional[str] = Form(None),
|
||||
):
|
||||
"""Embed an uploaded image via CLIP, then search Qdrant for similar vectors."""
|
||||
data = await file.read()
|
||||
vector = await _embed_bytes(data)
|
||||
return _do_search(vector, int(limit), score_threshold, collection, {})
|
||||
|
||||
|
||||
@app.post("/search/vector")
|
||||
def search_vector(req: SearchVectorRequest):
|
||||
"""Search Qdrant using a pre-computed vector."""
|
||||
return _do_search(req.vector, req.limit, req.score_threshold, req.collection, req.filter_metadata)
|
||||
|
||||
|
||||
def _do_search(
|
||||
vector: List[float],
|
||||
limit: int,
|
||||
score_threshold: Optional[float],
|
||||
collection: Optional[str],
|
||||
filter_metadata: Dict[str, Any],
|
||||
):
|
||||
col = _col(collection)
|
||||
qfilter = _build_filter(filter_metadata)
|
||||
|
||||
results = client.query_points(
|
||||
collection_name=col,
|
||||
query=vector,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_filter=qfilter,
|
||||
)
|
||||
|
||||
hits = []
|
||||
for point in results.points:
|
||||
hits.append({
|
||||
"id": point.id,
|
||||
"score": point.score,
|
||||
"metadata": point.payload,
|
||||
})
|
||||
|
||||
return {"results": hits, "collection": col, "count": len(hits)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Delete points
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.post("/delete")
|
||||
def delete_points(req: DeleteRequest):
|
||||
col = _col(req.collection)
|
||||
client.delete(
|
||||
collection_name=col,
|
||||
points_selector=req.ids,
|
||||
)
|
||||
return {"deleted": req.ids, "collection": col}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Get point by ID
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/points/{point_id}")
|
||||
def get_point(point_id: str, collection: Optional[str] = None):
|
||||
col = _col(collection)
|
||||
try:
|
||||
points = client.retrieve(collection_name=col, ids=[point_id], with_vectors=True)
|
||||
if not points:
|
||||
raise HTTPException(404, f"Point '{point_id}' not found")
|
||||
p = points[0]
|
||||
return {
|
||||
"id": p.id,
|
||||
"vector": p.vector,
|
||||
"metadata": p.payload,
|
||||
"collection": col,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(404, str(e))
|
||||
8
qdrant/requirements.txt
Normal file
8
qdrant/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi==0.115.5
|
||||
uvicorn[standard]==0.30.6
|
||||
python-multipart==0.0.9
|
||||
requests==2.32.3
|
||||
pillow==10.4.0
|
||||
qdrant-client==1.12.1
|
||||
httpx==0.27.2
|
||||
numpy
|
||||
Reference in New Issue
Block a user