first commit
This commit is contained in:
68
.gitignore
vendored
Normal file
68
.gitignore
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
# Python caches
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Distribution / packaging
|
||||
build/
|
||||
dist/
|
||||
*.egg-info/
|
||||
pip-wheel-metadata/
|
||||
|
||||
# Virtual environments
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
.env
|
||||
.env.*
|
||||
|
||||
# Pytest / tooling
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
coverage.xml
|
||||
|
||||
# IDEs and editors
|
||||
.vscode/
|
||||
.idea/
|
||||
*.sublime-workspace
|
||||
*.sublime-project
|
||||
|
||||
# Jupyter
|
||||
.ipynb_checkpoints
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Docker
|
||||
docker-compose.override.yml
|
||||
*.log
|
||||
|
||||
# Qdrant/DB data (if accidentally checked in)
|
||||
qdrant_data/
|
||||
|
||||
# Model weights & caches
|
||||
*.pt
|
||||
*.pth
|
||||
*.bin
|
||||
*.ckpt
|
||||
|
||||
# Numpy arrays
|
||||
*.npy
|
||||
|
||||
# Logs
|
||||
logs/
|
||||
*.log
|
||||
|
||||
# Misc
|
||||
*.sqlite3
|
||||
|
||||
# Local secrets
|
||||
*.pem
|
||||
*.key
|
||||
secrets.json
|
||||
|
||||
# End of file
|
||||
16
Dockerfile
Normal file
16
Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
||||
FROM python:3.11-slim

WORKDIR /app

# Use apt-get (the stable scripting CLI; plain `apt` warns in non-interactive
# use) and remove only the package lists, matching the per-service
# Dockerfiles (blip/Dockerfile, clip/Dockerfile) for consistency.
RUN apt-get update && apt-get install -y --no-install-recommends \
    ffmpeg \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies first so source-only edits don't invalidate the
# pip layer cache.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
138
README.md
Normal file
138
README.md
Normal file
@@ -0,0 +1,138 @@
|
||||
# Skinbase Vision Stack (CLIP + BLIP + YOLO + Qdrant) – Dockerized FastAPI
|
||||
|
||||
This repository provides **four standalone vision services** (CLIP / BLIP / YOLO / Qdrant)
|
||||
and a **Gateway API** that can call them individually or together.
|
||||
|
||||
## Services & Ports
|
||||
|
||||
- `gateway` (exposed): `https://vision.klevze.net`
|
||||
- `clip`: internal only
|
||||
- `blip`: internal only
|
||||
- `yolo`: internal only
|
||||
- `qdrant`: vector DB (port `6333` exposed for direct access)
|
||||
- `qdrant-svc`: internal Qdrant API wrapper
|
||||
|
||||
## Run
|
||||
|
||||
```bash
|
||||
docker compose up -d --build
|
||||
```
|
||||
|
||||
## Health
|
||||
|
||||
```bash
|
||||
curl https://vision.klevze.net/health
|
||||
```
|
||||
|
||||
## Universal analyze (ALL)
|
||||
|
||||
### With URL
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/all \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","limit":5}'
|
||||
```
|
||||
|
||||
### With file upload (multipart)
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/all/file \
|
||||
-F "file=@/path/to/image.webp" \
|
||||
-F "limit=5"
|
||||
```
|
||||
|
||||
## Individual services (via gateway)
|
||||
|
||||
### CLIP tags
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/clip -H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","limit":5}'
|
||||
```
|
||||
|
||||
### CLIP tags (file)
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/clip/file \
|
||||
-F "file=@/path/to/image.webp" \
|
||||
-F "limit=5"
|
||||
```
|
||||
|
||||
### BLIP caption
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/blip -H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","variants":3}'
|
||||
```
|
||||
|
||||
### BLIP caption (file)
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/blip/file \
|
||||
-F "file=@/path/to/image.webp" \
|
||||
-F "variants=3" \
|
||||
-F "max_length=60"
|
||||
```
|
||||
|
||||
### YOLO detect
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/yolo -H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","conf":0.25}'
|
||||
```
|
||||
|
||||
### YOLO detect (file)
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/yolo/file \
|
||||
-F "file=@/path/to/image.webp" \
|
||||
-F "conf=0.25"
|
||||
```
|
||||
|
||||
## Vector DB (Qdrant) via gateway
|
||||
|
||||
### Store image embedding by URL
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/upsert \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","id":"img-001","metadata":{"category":"wallpaper"}}'
|
||||
```
|
||||
|
||||
### Store image embedding by file upload
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/upsert/file \
|
||||
-F "file=@/path/to/image.webp" \
|
||||
-F 'id=img-002' \
|
||||
-F 'metadata_json={"category":"photo"}'
|
||||
```
|
||||
|
||||
### Search similar images by URL
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/search \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","limit":5}'
|
||||
```
|
||||
|
||||
### Search similar images by file upload
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/search/file \
|
||||
-F "file=@/path/to/image.webp" \
|
||||
-F "limit=5"
|
||||
```
|
||||
|
||||
### List collections
|
||||
```bash
|
||||
curl https://vision.klevze.net/vectors/collections
|
||||
```
|
||||
|
||||
### Get collection info
|
||||
```bash
|
||||
curl https://vision.klevze.net/vectors/collections/images
|
||||
```
|
||||
|
||||
### Delete points
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/delete \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"ids":["img-001","img-002"]}'
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- This is a **starter scaffold**. Models are loaded at service startup.
|
||||
- Qdrant data is persisted via a Docker volume (`qdrant_data`).
|
||||
- For production: add auth, rate limits, and restrict gateway exposure (private network).
|
||||
- GPU: you can add NVIDIA runtime later (compose profiles) if needed.
|
||||
303
USAGE.md
Normal file
303
USAGE.md
Normal file
@@ -0,0 +1,303 @@
|
||||
# Skinbase Vision Stack — Usage Guide
|
||||
|
||||
This document explains how to run and use the Skinbase Vision Stack (Gateway + CLIP, BLIP, YOLO, Qdrant services).
|
||||
|
||||
## Overview
|
||||
|
||||
- Services: `gateway`, `clip`, `blip`, `yolo`, `qdrant`, `qdrant-svc` (FastAPI each, except `qdrant` which is the official Qdrant DB).
|
||||
- Gateway is the public API endpoint; the other services are internal.
|
||||
|
||||
## Model overview
|
||||
|
||||
- **CLIP**: Contrastive Language–Image Pretraining — maps images and text into a shared embedding space. Used for zero-shot image tagging, similarity search, and returning ranked tags with confidence scores.
|
||||
|
||||
- **BLIP**: Bootstrapping Language-Image Pre-training — a vision–language model for image captioning and multimodal generation. BLIP produces human-readable captions (multiple `variants` supported) and can be tuned with `max_length`.
|
||||
|
||||
- **YOLO**: You Only Look Once — a family of real-time object-detection models. YOLO returns detected objects with `class`, `confidence`, and `bbox` (bounding box coordinates); use `conf` to filter low-confidence detections.
|
||||
|
||||
- **Qdrant**: High-performance vector similarity search engine. Stores CLIP image embeddings and enables reverse image search (find similar images). The `qdrant-svc` wrapper auto-embeds images via CLIP before upserting.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Docker Desktop (with `docker compose`) or a Docker environment.
|
||||
- Recommended: at least 8GB RAM for CPU-only; more for model memory or GPU use.
|
||||
|
||||
## Start the stack
|
||||
|
||||
Run from repository root:
|
||||
|
||||
```bash
|
||||
docker compose up -d --build
|
||||
```
|
||||
|
||||
Stop:
|
||||
|
||||
```bash
|
||||
docker compose down
|
||||
```
|
||||
|
||||
View logs:
|
||||
|
||||
```bash
|
||||
docker compose logs -f
|
||||
docker compose logs -f gateway
|
||||
```
|
||||
|
||||
## Health
|
||||
|
||||
Check the gateway health endpoint:
|
||||
|
||||
```bash
|
||||
curl https://vision.klevze.net/health
|
||||
```
|
||||
|
||||
## Universal analyze (ALL)
|
||||
|
||||
Analyze an image by URL (gateway aggregates CLIP, BLIP, YOLO):
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/all \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","limit":5}'
|
||||
```
|
||||
|
||||
File upload (multipart):
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/all/file \
|
||||
-F "file=@/path/to/image.webp" \
|
||||
-F "limit=5"
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- `limit`: optional integer to limit returned tag/caption items.
|
||||
|
||||
## Individual services (via gateway)
|
||||
|
||||
These endpoints call the specific service through the gateway.
|
||||
|
||||
### CLIP — tags
|
||||
|
||||
URL request:
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/clip \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","limit":5}'
|
||||
```
|
||||
|
||||
File upload:
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/clip/file \
|
||||
-F "file=@/path/to/image.webp" \
|
||||
-F "limit=5"
|
||||
```
|
||||
|
||||
Return: JSON list of tags with confidence scores.
|
||||
|
||||
### BLIP — captioning
|
||||
|
||||
URL request:
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/blip \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","variants":3}'
|
||||
```
|
||||
|
||||
File upload:
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/blip/file \
|
||||
-F "file=@/path/to/image.webp" \
|
||||
-F "variants=3" \
|
||||
-F "max_length=60"
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- `variants`: number of caption variants to return.
|
||||
- `max_length`: optional maximum caption length.
|
||||
|
||||
Return: a primary `caption` string plus a list of alternative caption `variants`.
|
||||
|
||||
### YOLO — object detection
|
||||
|
||||
URL request:
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/yolo \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","conf":0.25}'
|
||||
```
|
||||
|
||||
File upload:
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/analyze/yolo/file \
|
||||
-F "file=@/path/to/image.webp" \
|
||||
-F "conf=0.25"
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- `conf`: confidence threshold (0.0–1.0).
|
||||
|
||||
Return: detected objects with `class`, `confidence`, and `bbox` (bounding box coordinates).
|
||||
|
||||
### Qdrant — vector storage & similarity search
|
||||
|
||||
The Qdrant integration lets you store image embeddings and find visually similar images. Embeddings are generated automatically by the CLIP service.
|
||||
|
||||
#### Upsert (store) an image by URL
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/upsert \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","id":"img-001","metadata":{"category":"wallpaper","source":"upload"}}'
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- `url` (required): image URL to embed and store.
|
||||
- `id` (optional): custom string ID for the point; auto-generated if omitted.
|
||||
- `metadata` (optional): arbitrary key-value payload stored alongside the vector.
|
||||
- `collection` (optional): target collection name (defaults to `images`).
|
||||
|
||||
#### Upsert by file upload
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/upsert/file \
|
||||
-F "file=@/path/to/image.webp" \
|
||||
-F 'id=img-002' \
|
||||
-F 'metadata_json={"category":"photo"}'
|
||||
```
|
||||
|
||||
#### Upsert a pre-computed vector
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/upsert/vector \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"vector":[0.1,0.2,...],"id":"img-003","metadata":{"custom":"data"}}'
|
||||
```
|
||||
|
||||
#### Search similar images by URL
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/search \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","limit":5}'
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- `url` (required): query image URL.
|
||||
- `limit` (optional, default 5): number of results.
|
||||
- `score_threshold` (optional): minimum cosine similarity (0.0–1.0).
|
||||
- `filter_metadata` (optional): filter results by metadata, e.g. `{"category":"wallpaper"}`.
|
||||
- `collection` (optional): collection to search.
|
||||
|
||||
Return: list of `{"id", "score", "metadata"}` sorted by similarity.
|
||||
|
||||
#### Search by file upload
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/search/file \
|
||||
-F "file=@/path/to/image.webp" \
|
||||
-F "limit=5"
|
||||
```
|
||||
|
||||
#### Search by pre-computed vector
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/search/vector \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"vector":[0.1,0.2,...],"limit":5}'
|
||||
```
|
||||
|
||||
#### Collection management
|
||||
|
||||
List all collections:
|
||||
```bash
|
||||
curl https://vision.klevze.net/vectors/collections
|
||||
```
|
||||
|
||||
Get collection info:
|
||||
```bash
|
||||
curl https://vision.klevze.net/vectors/collections/images
|
||||
```
|
||||
|
||||
Create a custom collection:
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/collections \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"name":"my_collection","vector_dim":512,"distance":"cosine"}'
|
||||
```
|
||||
|
||||
Delete a collection:
|
||||
```bash
|
||||
curl -X DELETE https://vision.klevze.net/vectors/collections/my_collection
|
||||
```
|
||||
|
||||
#### Delete points
|
||||
|
||||
```bash
|
||||
curl -X POST https://vision.klevze.net/vectors/delete \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"ids":["img-001","img-002"]}'
|
||||
```
|
||||
|
||||
#### Get a point by ID
|
||||
|
||||
```bash
|
||||
curl https://vision.klevze.net/vectors/points/img-001
|
||||
```
|
||||
|
||||
## Request/Response notes
|
||||
|
||||
- For URL requests use `Content-Type: application/json`.
|
||||
- For uploads use `multipart/form-data` with a `file` field.
|
||||
- The gateway aggregates and normalizes outputs for `/analyze/all`.
|
||||
|
||||
## Running a single service
|
||||
|
||||
To run only one service via docker compose:
|
||||
|
||||
```bash
|
||||
docker compose up -d --build clip
|
||||
```
|
||||
|
||||
Or run locally (Python env) from the service folder:
|
||||
|
||||
```bash
|
||||
# inside clip/ or blip/ or yolo/
|
||||
uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
## Production tips
|
||||
|
||||
- Add authentication (API keys or OAuth) at the gateway.
|
||||
- Add rate-limiting and per-client quotas.
|
||||
- Keep model services on an internal Docker network.
|
||||
- For GPU: enable NVIDIA runtime and update service Dockerfiles / compose profiles.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Service fails to start: check `docker compose logs <service>` for model load errors.
|
||||
- High memory / OOM: increase host memory or reduce model footprint; consider GPUs.
|
||||
- Slow startup: model weights load on service startup — expect extra time.
|
||||
|
||||
## Extending
|
||||
|
||||
- Swap or update models in each service by editing that service's `main.py`.
|
||||
- Add request validation, timeouts, and retries in the gateway to improve robustness.
|
||||
|
||||
## Files of interest
|
||||
|
||||
- `docker-compose.yml` — composition and service definitions.
|
||||
- `gateway/` — gateway FastAPI server.
|
||||
- `clip/`, `blip/`, `yolo/` — service implementations and Dockerfiles.
|
||||
- `qdrant/` — Qdrant API wrapper service (FastAPI).
|
||||
- `common/` — shared helpers (e.g., image I/O).
|
||||
|
||||
---
|
||||
|
||||
If you want, I can merge these same contents into the project `README.md`,
|
||||
create a Postman collection, or add example response schemas for each endpoint.
|
||||
17
blip/Dockerfile
Normal file
17
blip/Dockerfile
Normal file
@@ -0,0 +1,17 @@
|
||||
FROM python:3.11-slim

WORKDIR /app

# Minimal system deps; only CA certificates are needed for HTTPS model
# downloads. Package lists are removed to keep the layer small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

# Dependencies first so code changes don't bust the pip layer cache.
COPY blip/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

# Service code plus the shared image I/O helpers.
COPY blip /app
COPY common /app/common

# Unbuffered stdout/stderr so logs appear immediately in `docker logs`.
ENV PYTHONUNBUFFERED=1

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
78
blip/main.py
Normal file
78
blip/main.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations

import os
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
from transformers import BlipProcessor, BlipForConditionalGeneration

from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError

# Captioning checkpoint; override via BLIP_MODEL env var for larger models.
BLIP_MODEL = os.getenv("BLIP_MODEL", "Salesforce/blip-image-captioning-base")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

app = FastAPI(title="Skinbase BLIP Service", version="1.0.0")

# Loaded once at import time (service startup) and shared by all requests;
# expect a slow first start while weights download.
processor = BlipProcessor.from_pretrained(BLIP_MODEL)
model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL).to(DEVICE).eval()
|
||||
|
||||
|
||||
class CaptionRequest(BaseModel):
    """Request body for POST /caption."""
    # Image URL to fetch; the endpoint rejects a missing/empty URL with 400.
    url: Optional[str] = None
    # Total captions desired: one beam-search caption plus up to
    # (variants - 1) sampled alternatives.
    variants: int = Field(default=3, ge=0, le=10)
    # Maximum generated length passed to model.generate().
    max_length: int = Field(default=60, ge=10, le=200)
|
||||
|
||||
|
||||
@app.get("/health")
def health():
    """Liveness probe reporting the active device and model."""
    info = {"status": "ok", "device": DEVICE, "model": BLIP_MODEL}
    return info
|
||||
|
||||
|
||||
def _caption_bytes(data: bytes, variants: int, max_length: int):
    """Caption raw image bytes with BLIP.

    Produces one deterministic beam-search caption plus up to
    ``variants - 1`` distinct sampled alternatives.
    """
    pil_img = bytes_to_pil(data)
    encoded = processor(pil_img, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        beam_ids = model.generate(**encoded, max_length=max_length, num_beams=5)
    primary = processor.decode(beam_ids[0], skip_special_tokens=True)

    alternatives = []
    # Sampling is best-effort: candidates equal to the primary caption or an
    # earlier variant are discarded, so fewer than requested may come back.
    for _ in range(max(0, variants - 1)):
        with torch.no_grad():
            sampled_ids = model.generate(
                **encoded,
                max_length=max_length,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.9,
            )
        candidate = processor.decode(sampled_ids[0], skip_special_tokens=True)
        if candidate != primary and candidate not in alternatives:
            alternatives.append(candidate)

    return {"caption": primary, "variants": alternatives, "model": BLIP_MODEL}
|
||||
|
||||
|
||||
@app.post("/caption")
def caption(req: CaptionRequest):
    """Caption an image referenced by URL (JSON body)."""
    if not req.url:
        raise HTTPException(400, "url is required")
    try:
        raw = fetch_url_bytes(req.url)
        result = _caption_bytes(raw, req.variants, req.max_length)
    except ImageLoadError as exc:
        # Fetch/decode failures are client errors, not server faults.
        raise HTTPException(400, str(exc))
    return result
|
||||
|
||||
|
||||
@app.post("/caption/file")
async def caption_file(
    file: UploadFile = File(...),
    variants: int = Form(3),
    max_length: int = Form(60),
):
    """Caption an uploaded image (multipart form)."""
    data = await file.read()
    try:
        return _caption_bytes(data, int(variants), int(max_length))
    except ImageLoadError as e:
        # Consistent with POST /caption: an undecodable upload is a client
        # error (400), not an unhandled 500.
        raise HTTPException(400, str(e))
|
||||
8
blip/requirements.txt
Normal file
8
blip/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi==0.115.5
|
||||
uvicorn[standard]==0.30.6
|
||||
python-multipart==0.0.9
|
||||
requests==2.32.3
|
||||
pillow==10.4.0
|
||||
torch==2.4.1
|
||||
torchvision==0.19.1
|
||||
transformers==4.44.2
|
||||
17
clip/Dockerfile
Normal file
17
clip/Dockerfile
Normal file
@@ -0,0 +1,17 @@
|
||||
FROM python:3.11-slim

WORKDIR /app

# Minimal system deps; CA certificates are needed for HTTPS model downloads.
RUN apt-get update && apt-get install -y --no-install-recommends \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

# Dependencies first so code changes don't bust the pip layer cache.
COPY clip/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

# Service code plus the shared image I/O helpers.
COPY clip /app
COPY common /app/common

# Unbuffered stdout/stderr so logs appear immediately in `docker logs`.
ENV PYTHONUNBUFFERED=1

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
161
clip/main.py
Normal file
161
clip/main.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from __future__ import annotations

import os
from typing import List, Optional

import torch
import open_clip
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
import numpy as np

from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError

# open_clip architecture and pretrained weight source; both overridable
# via environment to match deployment configuration.
MODEL_NAME = os.getenv("MODEL_NAME", "ViT-B-32")
MODEL_PRETRAINED = os.getenv("MODEL_PRETRAINED", "openai")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Starter vocab (replace with DB-driven vocab later)
TAGS: List[str] = [
    "wallpaper", "4k wallpaper", "8k wallpaper",
    "cyberpunk", "neon", "city", "night", "sci-fi", "space",
    "fantasy", "anime", "digital art", "abstract", "minimal",
    "landscape", "nature", "mountains", "forest", "ocean", "sunset",
    "photography", "portrait", "architecture", "cars", "gaming",
]

app = FastAPI(title="Skinbase CLIP Service", version="1.0.0")

# Model, transforms and tokenizer are created once at import time (service
# startup) and shared by all request handlers.
model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=MODEL_PRETRAINED)
tokenizer = open_clip.get_tokenizer(MODEL_NAME)
model = model.to(DEVICE).eval()
|
||||
|
||||
|
||||
class AnalyzeRequest(BaseModel):
    """Request body for POST /analyze."""
    # Image URL to fetch; the endpoint rejects a missing/empty URL with 400.
    url: Optional[str] = None
    # Maximum number of tags to return (capped at the vocabulary size).
    limit: int = Field(default=5, ge=1, le=50)
    # Optional minimum confidence; tags scoring below it are dropped.
    threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class EmbedRequest(BaseModel):
    """Request body for POST /embed."""
    # Image URL to fetch; the endpoint rejects a missing/empty URL with 400.
    url: Optional[str] = None
    # Embedding backend. NOTE: pydantic v2 (which fastapi 0.115.x pulls in)
    # removed Field(regex=...); it raises PydanticUserError at import time.
    # The v2 spelling is Field(pattern=...).
    backend: Optional[str] = Field(default="openclip", pattern="^(openclip|hf)$")
    # Backend-specific model name; defaults are chosen by the service.
    model: Optional[str] = None
    # open_clip pretrained source (openclip backend only).
    pretrained: Optional[str] = None
|
||||
|
||||
|
||||
@app.get("/health")
def health():
    """Liveness probe exposing the active device and model configuration."""
    payload = {"status": "ok", "device": DEVICE, "model": MODEL_NAME, "pretrained": MODEL_PRETRAINED}
    return payload
|
||||
|
||||
|
||||
def _analyze_image_bytes(data: bytes, limit: int, threshold: Optional[float]):
    """Zero-shot tag raw image bytes against the TAGS vocabulary.

    Returns the top-`limit` tags with softmax confidences; tags scoring
    below `threshold` (when given) are dropped from the result.
    """
    img = bytes_to_pil(data)
    image_input = preprocess(img).unsqueeze(0).to(DEVICE)
    text = tokenizer(TAGS).to(DEVICE)

    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text)

        # Normalize so dot product equals cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Scale by the model's learned temperature before softmax: raw cosine
        # similarities lie in roughly [-1, 1], so an unscaled softmax over the
        # vocabulary is nearly uniform and the confidences are meaningless.
        # The temperature does not change the ranking (softmax is monotonic).
        logits = model.logit_scale.exp() * (image_features @ text_features.T)
    probs = logits.softmax(dim=-1)

    topk = probs[0].topk(min(limit, len(TAGS)))

    results = []
    for score, idx in zip(topk.values, topk.indices):
        conf = float(score)
        if threshold is not None and conf < float(threshold):
            continue
        results.append({"tag": TAGS[int(idx)], "confidence": conf})

    return {"tags": results, "model": MODEL_NAME, "dim": int(text_features.shape[-1])}
|
||||
|
||||
|
||||
def _embed_image_bytes(data: bytes, backend: str = "openclip", model_name: Optional[str] = None, pretrained: Optional[str] = None):
    """Embed raw image bytes into an L2-normalized vector.

    backend="openclip" reuses the preloaded global model when the requested
    (model, pretrained) pair matches the service configuration; any other
    pair is loaded on demand. backend="hf" loads a HuggingFace CLIP model.

    NOTE(review): non-default open_clip models and every HF model are
    re-loaded on each call — no cache. Fine for a starter scaffold; add a
    cache if these paths get traffic.
    """
    img = bytes_to_pil(data)

    if backend == "openclip":
        # prefer already-loaded model when matching global config
        use_model_name = model_name or MODEL_NAME
        use_pretrained = pretrained or MODEL_PRETRAINED
        if use_model_name == MODEL_NAME and use_pretrained == MODEL_PRETRAINED:
            _model = model
            _preprocess = preprocess
            device = DEVICE
        else:
            import open_clip as _oc
            _model, _, _preprocess = _oc.create_model_and_transforms(use_model_name, pretrained=use_pretrained)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            _model = _model.to(device).eval()

        image_input = _preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = _model.encode_image(image_input)
            # Normalize so downstream cosine similarity is a plain dot product.
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        vec = image_features.cpu().numpy()[0]

    else:
        # HuggingFace CLIP backend
        from transformers import CLIPProcessor, CLIPModel

        hf_model_name = model_name or "openai/clip-vit-base-patch32"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        hf_model = CLIPModel.from_pretrained(hf_model_name).to(device).eval()
        processor = CLIPProcessor.from_pretrained(hf_model_name)
        inputs = processor(images=img, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            feats = hf_model.get_image_features(**inputs)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        vec = feats.cpu().numpy()[0]

    # NOTE(review): when backend="hf" and model_name is None the response
    # reports model=None even though the default HF checkpoint was used.
    return {"vector": vec.tolist(), "dim": int(np.asarray(vec).shape[-1]), "backend": backend, "model": model_name or (MODEL_NAME if backend == "openclip" else None)}
|
||||
|
||||
|
||||
@app.post("/analyze")
def analyze(req: AnalyzeRequest):
    """Zero-shot tag an image referenced by URL (JSON body)."""
    if not req.url:
        raise HTTPException(400, "url is required")
    try:
        payload = fetch_url_bytes(req.url)
        result = _analyze_image_bytes(payload, req.limit, req.threshold)
    except ImageLoadError as exc:
        # Fetch/decode failures are client errors, not server faults.
        raise HTTPException(400, str(exc))
    return result
|
||||
|
||||
|
||||
@app.post("/analyze/file")
async def analyze_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    threshold: Optional[float] = Form(None),
):
    """Zero-shot tag an uploaded image (multipart form)."""
    data = await file.read()
    try:
        return _analyze_image_bytes(data, int(limit), threshold)
    except ImageLoadError as e:
        # Consistent with POST /analyze: an undecodable upload is a client
        # error (400), not an unhandled 500.
        raise HTTPException(400, str(e))
|
||||
|
||||
|
||||
@app.post("/embed")
def embed(req: EmbedRequest):
    """Embed an image referenced by URL into a normalized vector."""
    if not req.url:
        raise HTTPException(400, "url is required")
    try:
        raw = fetch_url_bytes(req.url)
        result = _embed_image_bytes(raw, backend=req.backend, model_name=req.model, pretrained=req.pretrained)
    except ImageLoadError as exc:
        # Fetch/decode failures are client errors, not server faults.
        raise HTTPException(400, str(exc))
    return result
|
||||
|
||||
|
||||
@app.post("/embed/file")
async def embed_file(
    file: UploadFile = File(...),
    backend: str = Form("openclip"),
    model: Optional[str] = Form(None),
    pretrained: Optional[str] = Form(None),
):
    """Embed an uploaded image (multipart form) into a normalized vector."""
    data = await file.read()
    try:
        return _embed_image_bytes(data, backend=backend, model_name=model, pretrained=pretrained)
    except ImageLoadError as e:
        # Consistent with POST /embed: an undecodable upload is a client
        # error (400), not an unhandled 500.
        raise HTTPException(400, str(e))
|
||||
10
clip/requirements.txt
Normal file
10
clip/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
fastapi==0.115.5
|
||||
uvicorn[standard]==0.30.6
|
||||
python-multipart==0.0.9
|
||||
requests==2.32.3
|
||||
pillow==10.4.0
|
||||
torch==2.4.1
|
||||
torchvision==0.19.1
|
||||
open_clip_torch==2.26.1
|
||||
transformers
|
||||
numpy
|
||||
104
clip/vectorize.py
Normal file
104
clip/vectorize.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations

import argparse
import os
from typing import Tuple

import numpy as np
import torch
from PIL import Image

from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError

# Both CLIP backends are optional at import time; each loader below raises
# a clear RuntimeError when its library is missing, instead of the whole
# CLI failing to import.
try:
    import open_clip
except Exception:
    open_clip = None

try:
    from transformers import CLIPProcessor, CLIPModel
except Exception:
    CLIPModel = None
    CLIPProcessor = None
|
||||
|
||||
|
||||
def load_openclip(model_name: str = "ViT-B-32", pretrained: str = "openai") -> Tuple:
    """Load an open_clip model; return (model, preprocess, device)."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if open_clip is None:
        raise RuntimeError("open_clip is not installed")
    clip_model, _, transform = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    clip_model = clip_model.to(device).eval()
    return clip_model, transform, device
|
||||
|
||||
|
||||
def embed_openclip(model, preprocess, device, pil_image: Image.Image) -> np.ndarray:
    """Return the L2-normalized open_clip embedding of *pil_image*."""
    batch = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        feats = model.encode_image(batch)
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.cpu().numpy()[0]
|
||||
|
||||
|
||||
def load_hf_clip(model_name: str = "openai/clip-vit-base-patch32") -> Tuple:
    """Load a HuggingFace CLIP model; return (model, processor, device)."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if CLIPModel is None or CLIPProcessor is None:
        raise RuntimeError("transformers (CLIP) is not installed")
    hf_model = CLIPModel.from_pretrained(model_name).to(device).eval()
    hf_processor = CLIPProcessor.from_pretrained(model_name)
    return hf_model, hf_processor, device
|
||||
|
||||
|
||||
def embed_hf_clip(model, processor, device, pil_image: Image.Image) -> np.ndarray:
    """Return the L2-normalized HF CLIP embedding of *pil_image*."""
    encoded = processor(images=pil_image, return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        feats = model.get_image_features(**encoded)
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.cpu().numpy()[0]
|
||||
|
||||
|
||||
def load_image(path_or_url: str) -> Image.Image:
    """Open a local file or fetch a URL; return an RGB PIL image."""
    if path_or_url.startswith(("http://", "https://")):
        return bytes_to_pil(fetch_url_bytes(path_or_url))
    return Image.open(path_or_url).convert("RGB")
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: embed one image (file path or URL) and print/save it."""
    parser = argparse.ArgumentParser(description="Vectorize an image using CLIP (open_clip or HuggingFace)")
    parser.add_argument("input", help="Path to image file or URL")
    parser.add_argument("--backend", choices=("openclip", "hf"), default="openclip")
    parser.add_argument("--model", default=None, help="Model name (backend-specific)")
    parser.add_argument("--pretrained", default="openai", help="open_clip pretrained source (openclip backend)")
    parser.add_argument("--out", default=None, help="Output .npy path (defaults to stdout)")
    args = parser.parse_args()

    try:
        img = load_image(args.input)
    except ImageLoadError as e:
        # Exit with a message rather than a traceback for user-facing errors.
        raise SystemExit(f"Failed to load image: {e}")

    if args.backend == "openclip":
        # MODEL_NAME env var mirrors the CLIP service configuration.
        model_name = args.model or os.getenv("MODEL_NAME", "ViT-B-32")
        pretrained = args.pretrained
        model, preprocess, device = load_openclip(model_name, pretrained=pretrained)
        vec = embed_openclip(model, preprocess, device, img)
    else:
        model_name = args.model or "openai/clip-vit-base-patch32"
        model, processor, device = load_hf_clip(model_name)
        vec = embed_hf_clip(model, processor, device, img)

    # float32 keeps .npy output compact and matches typical vector-DB dtypes.
    vec = np.asarray(vec, dtype=np.float32)

    if args.out:
        np.save(args.out, vec)
        print(f"Saved vector shape={vec.shape} to {args.out}")
    else:
        # Print the shape followed by the full vector (can be large).
        print(f"vector_shape={vec.shape}")
        print(np.array2string(vec, precision=6, separator=", "))


if __name__ == "__main__":
    main()
|
||||
0
common/__init__.py
Normal file
0
common/__init__.py
Normal file
35
common/image_io.py
Normal file
35
common/image_io.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from typing import Optional, Tuple
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
DEFAULT_MAX_BYTES = 50 * 1024 * 1024  # 50 MiB hard cap on downloaded image size
|
||||
|
||||
class ImageLoadError(Exception):
    """Raised when an image cannot be fetched from a URL or decoded from bytes."""
    pass
|
||||
|
||||
def fetch_url_bytes(url: str, timeout: float = 10.0, max_bytes: int = DEFAULT_MAX_BYTES) -> bytes:
    """Download *url* and return the raw bytes, enforcing a size cap.

    Args:
        url: HTTP(S) URL of the image.
        timeout: per-request timeout in seconds.
        max_bytes: maximum allowed download size; exceeding it aborts the fetch.

    Raises:
        ImageLoadError: if the download fails or exceeds ``max_bytes``.
    """
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            buf = io.BytesIO()
            total = 0
            # Stream in 64 KiB chunks so oversized bodies are rejected early
            # without buffering the whole response.
            for chunk in r.iter_content(chunk_size=1024 * 64):
                if not chunk:
                    continue
                total += len(chunk)
                if total > max_bytes:
                    raise ImageLoadError(f"Image exceeds max_bytes={max_bytes}")
                buf.write(chunk)
            return buf.getvalue()
    except ImageLoadError:
        # Bug fix: previously the size-limit error raised above was caught by
        # the broad handler below and re-wrapped as a misleading
        # "Cannot fetch image url" message. Propagate it unchanged.
        raise
    except Exception as e:
        raise ImageLoadError(f"Cannot fetch image url: {e}") from e
|
||||
|
||||
def bytes_to_pil(data: bytes) -> Image.Image:
    """Decode raw bytes into an RGB PIL image, wrapping failures in ImageLoadError."""
    try:
        return Image.open(io.BytesIO(data)).convert("RGB")
    except Exception as e:
        raise ImageLoadError(f"Cannot decode image: {e}") from e
|
||||
68
docker-compose.yml
Normal file
68
docker-compose.yml
Normal file
@@ -0,0 +1,68 @@
|
||||
services:
  gateway:
    build:
      context: .
      dockerfile: gateway/Dockerfile
    ports:
      - "8003:8000"
    environment:
      - CLIP_URL=http://clip:8000
      - BLIP_URL=http://blip:8000
      - YOLO_URL=http://yolo:8000
      - QDRANT_SVC_URL=http://qdrant-svc:8000
      - VISION_TIMEOUT=300
      - MAX_IMAGE_BYTES=52428800
    depends_on:
      - clip
      - blip
      - yolo
      - qdrant-svc

  qdrant:
    image: qdrant/qdrant:latest
    ports:
      - "6333:6333"
    volumes:
      - qdrant_data:/qdrant/storage
    environment:
      - QDRANT__SERVICE__GRPC_PORT=6334

  qdrant-svc:
    build:
      context: .
      dockerfile: qdrant/Dockerfile
    environment:
      - QDRANT_HOST=qdrant
      - QDRANT_PORT=6333
      - CLIP_URL=http://clip:8000
      - COLLECTION_NAME=images
      # Must match the CLIP model's embedding size (ViT-B-32 -> 512).
      - VECTOR_DIM=512
    depends_on:
      - qdrant
      - clip

  clip:
    build:
      context: .
      dockerfile: clip/Dockerfile
    environment:
      - MODEL_NAME=ViT-B-32
      - MODEL_PRETRAINED=openai

  blip:
    build:
      context: .
      dockerfile: blip/Dockerfile
    environment:
      # Fix: "Salesforce/blip-image-captioning-small" is not a published
      # checkpoint on the HuggingFace Hub (only -base and -large exist),
      # so the service would fail at model load. Use the base model.
      - BLIP_MODEL=Salesforce/blip-image-captioning-base

  yolo:
    build:
      context: .
      dockerfile: yolo/Dockerfile
    environment:
      - YOLO_MODEL=yolov8n.pt

volumes:
  qdrant_data:
|
||||
16
gateway/Dockerfile
Normal file
16
gateway/Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
||||
FROM python:3.11-slim

WORKDIR /app

# ca-certificates is required for outbound HTTPS calls (httpx).
RUN apt-get update && apt-get install -y --no-install-recommends \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies first so Docker layer caching survives code-only changes.
COPY gateway/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

COPY gateway /app

# Unbuffered stdout/stderr so logs appear immediately in `docker logs`.
ENV PYTHONUNBUFFERED=1

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
289
gateway/main.py
Normal file
289
gateway/main.py
Normal file
@@ -0,0 +1,289 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Upstream service base URLs; overridable via environment for non-compose runs.
CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000")
BLIP_URL = os.getenv("BLIP_URL", "http://blip:8000")
YOLO_URL = os.getenv("YOLO_URL", "http://yolo:8000")
QDRANT_SVC_URL = os.getenv("QDRANT_SVC_URL", "http://qdrant-svc:8000")
# Per-request timeout (seconds) for calls to the vision services.
# NOTE(review): docker-compose sets this to 300 while the default here is 20 —
# confirm which is intended for local runs.
VISION_TIMEOUT = float(os.getenv("VISION_TIMEOUT", "20"))

app = FastAPI(title="Skinbase Vision Gateway", version="1.0.0")
|
||||
|
||||
|
||||
class ClipRequest(BaseModel):
    """URL-based CLIP analyze request, proxied to the CLIP service."""
    url: Optional[str] = None  # validated as required in the endpoint handler
    limit: int = Field(default=5, ge=1, le=50)
    threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)


class BlipRequest(BaseModel):
    """URL-based BLIP captioning request."""
    url: Optional[str] = None
    variants: int = Field(default=3, ge=0, le=10)
    max_length: int = Field(default=60, ge=10, le=200)


class YoloRequest(BaseModel):
    """URL-based YOLO detection request."""
    url: Optional[str] = None
    conf: float = Field(default=0.25, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
async def _get_health(client: httpx.AsyncClient, base: str) -> Dict[str, Any]:
|
||||
try:
|
||||
r = await client.get(f"{base}/health")
|
||||
return r.json() if r.status_code == 200 else {"status": "bad", "code": r.status_code}
|
||||
except Exception:
|
||||
return {"status": "unreachable"}
|
||||
|
||||
|
||||
async def _post_json(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
r = await client.post(url, json=payload)
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:200]}")
|
||||
return r.json()
|
||||
|
||||
|
||||
async def _post_file(client: httpx.AsyncClient, url: str, data: bytes, fields: Dict[str, Any]) -> Dict[str, Any]:
|
||||
files = {"file": ("image", data, "application/octet-stream")}
|
||||
r = await client.post(url, data={k: str(v) for k, v in fields.items()}, files=files)
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:200]}")
|
||||
return r.json()
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
clip_h, blip_h, yolo_h, qdrant_h = await asyncio.gather(
|
||||
_get_health(client, CLIP_URL),
|
||||
_get_health(client, BLIP_URL),
|
||||
_get_health(client, YOLO_URL),
|
||||
_get_health(client, QDRANT_SVC_URL),
|
||||
)
|
||||
return {"status": "ok", "services": {"clip": clip_h, "blip": blip_h, "yolo": yolo_h, "qdrant": qdrant_h}}
|
||||
|
||||
|
||||
# ---- Individual analyze endpoints (URL) ----
|
||||
|
||||
@app.post("/analyze/clip")
|
||||
async def analyze_clip(req: ClipRequest):
|
||||
if not req.url:
|
||||
raise HTTPException(400, "url is required")
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{CLIP_URL}/analyze", req.model_dump())
|
||||
|
||||
|
||||
@app.post("/analyze/blip")
|
||||
async def analyze_blip(req: BlipRequest):
|
||||
if not req.url:
|
||||
raise HTTPException(400, "url is required")
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{BLIP_URL}/caption", req.model_dump())
|
||||
|
||||
|
||||
@app.post("/analyze/yolo")
|
||||
async def analyze_yolo(req: YoloRequest):
|
||||
if not req.url:
|
||||
raise HTTPException(400, "url is required")
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{YOLO_URL}/detect", req.model_dump())
|
||||
|
||||
|
||||
# ---- Individual analyze endpoints (file upload) ----
|
||||
|
||||
|
||||
@app.post("/analyze/clip/file")
|
||||
async def analyze_clip_file(
|
||||
file: UploadFile = File(...),
|
||||
limit: int = Form(5),
|
||||
threshold: Optional[float] = Form(None),
|
||||
):
|
||||
data = await file.read()
|
||||
fields: Dict[str, Any] = {"limit": int(limit)}
|
||||
if threshold is not None:
|
||||
fields["threshold"] = float(threshold)
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_file(client, f"{CLIP_URL}/analyze/file", data, fields)
|
||||
|
||||
|
||||
@app.post("/analyze/blip/file")
|
||||
async def analyze_blip_file(
|
||||
file: UploadFile = File(...),
|
||||
variants: int = Form(3),
|
||||
max_length: int = Form(60),
|
||||
):
|
||||
data = await file.read()
|
||||
fields = {"variants": int(variants), "max_length": int(max_length)}
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_file(client, f"{BLIP_URL}/caption/file", data, fields)
|
||||
|
||||
|
||||
@app.post("/analyze/yolo/file")
|
||||
async def analyze_yolo_file(
|
||||
file: UploadFile = File(...),
|
||||
conf: float = Form(0.25),
|
||||
):
|
||||
data = await file.read()
|
||||
fields = {"conf": float(conf)}
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_file(client, f"{YOLO_URL}/detect/file", data, fields)
|
||||
|
||||
|
||||
@app.post("/analyze/all")
|
||||
async def analyze_all(payload: Dict[str, Any]):
|
||||
url = payload.get("url")
|
||||
if not url:
|
||||
raise HTTPException(400, "url is required")
|
||||
|
||||
clip_req = {"url": url, "limit": int(payload.get("limit", 5)), "threshold": payload.get("threshold")}
|
||||
blip_req = {"url": url, "variants": int(payload.get("variants", 3)), "max_length": int(payload.get("max_length", 60))}
|
||||
yolo_req = {"url": url, "conf": float(payload.get("conf", 0.25))}
|
||||
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
clip_task = _post_json(client, f"{CLIP_URL}/analyze", clip_req)
|
||||
blip_task = _post_json(client, f"{BLIP_URL}/caption", blip_req)
|
||||
yolo_task = _post_json(client, f"{YOLO_URL}/detect", yolo_req)
|
||||
|
||||
clip_res, blip_res, yolo_res = await asyncio.gather(clip_task, blip_task, yolo_task)
|
||||
|
||||
return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res}
|
||||
|
||||
|
||||
# ---- Vector / Qdrant endpoints ----
|
||||
|
||||
@app.post("/vectors/upsert")
|
||||
async def vectors_upsert(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/upsert", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/upsert/file")
|
||||
async def vectors_upsert_file(
|
||||
file: UploadFile = File(...),
|
||||
id: Optional[str] = Form(None),
|
||||
collection: Optional[str] = Form(None),
|
||||
metadata_json: Optional[str] = Form(None),
|
||||
):
|
||||
data = await file.read()
|
||||
fields: Dict[str, Any] = {}
|
||||
if id is not None:
|
||||
fields["id"] = id
|
||||
if collection is not None:
|
||||
fields["collection"] = collection
|
||||
if metadata_json is not None:
|
||||
fields["metadata_json"] = metadata_json
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_file(client, f"{QDRANT_SVC_URL}/upsert/file", data, fields)
|
||||
|
||||
|
||||
@app.post("/vectors/upsert/vector")
|
||||
async def vectors_upsert_vector(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/upsert/vector", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/search")
|
||||
async def vectors_search(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/search", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/search/file")
|
||||
async def vectors_search_file(
|
||||
file: UploadFile = File(...),
|
||||
limit: int = Form(5),
|
||||
score_threshold: Optional[float] = Form(None),
|
||||
collection: Optional[str] = Form(None),
|
||||
):
|
||||
data = await file.read()
|
||||
fields: Dict[str, Any] = {"limit": int(limit)}
|
||||
if score_threshold is not None:
|
||||
fields["score_threshold"] = float(score_threshold)
|
||||
if collection is not None:
|
||||
fields["collection"] = collection
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_file(client, f"{QDRANT_SVC_URL}/search/file", data, fields)
|
||||
|
||||
|
||||
@app.post("/vectors/search/vector")
|
||||
async def vectors_search_vector(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/search/vector", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/delete")
|
||||
async def vectors_delete(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/delete", payload)
|
||||
|
||||
|
||||
@app.get("/vectors/collections")
|
||||
async def vectors_collections():
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
r = await client.get(f"{QDRANT_SVC_URL}/collections")
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}")
|
||||
return r.json()
|
||||
|
||||
|
||||
@app.post("/vectors/collections")
|
||||
async def vectors_create_collection(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/collections", payload)
|
||||
|
||||
|
||||
@app.get("/vectors/collections/{name}")
|
||||
async def vectors_collection_info(name: str):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
r = await client.get(f"{QDRANT_SVC_URL}/collections/{name}")
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}")
|
||||
return r.json()
|
||||
|
||||
|
||||
@app.delete("/vectors/collections/{name}")
|
||||
async def vectors_delete_collection(name: str):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
r = await client.delete(f"{QDRANT_SVC_URL}/collections/{name}")
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}")
|
||||
return r.json()
|
||||
|
||||
|
||||
@app.get("/vectors/points/{point_id}")
|
||||
async def vectors_get_point(point_id: str, collection: Optional[str] = None):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
params = {}
|
||||
if collection:
|
||||
params["collection"] = collection
|
||||
r = await client.get(f"{QDRANT_SVC_URL}/points/{point_id}", params=params)
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}")
|
||||
return r.json()
|
||||
|
||||
|
||||
# ---- File-based universal analyze ----
|
||||
|
||||
@app.post("/analyze/all/file")
|
||||
async def analyze_all_file(
|
||||
file: UploadFile = File(...),
|
||||
limit: int = Form(5),
|
||||
variants: int = Form(3),
|
||||
conf: float = Form(0.25),
|
||||
max_length: int = Form(60),
|
||||
):
|
||||
data = await file.read()
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
clip_task = _post_file(client, f"{CLIP_URL}/analyze/file", data, {"limit": limit})
|
||||
blip_task = _post_file(client, f"{BLIP_URL}/caption/file", data, {"variants": variants, "max_length": max_length})
|
||||
yolo_task = _post_file(client, f"{YOLO_URL}/detect/file", data, {"conf": conf})
|
||||
|
||||
clip_res, blip_res, yolo_res = await asyncio.gather(clip_task, blip_task, yolo_task)
|
||||
|
||||
return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res}
|
||||
4
gateway/requirements.txt
Normal file
4
gateway/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
fastapi==0.115.5
|
||||
uvicorn[standard]==0.30.6
|
||||
httpx==0.27.2
|
||||
python-multipart==0.0.9
|
||||
17
qdrant/Dockerfile
Normal file
17
qdrant/Dockerfile
Normal file
@@ -0,0 +1,17 @@
|
||||
FROM python:3.11-slim

WORKDIR /app

# ca-certificates is required for outbound HTTPS calls.
RUN apt-get update && apt-get install -y --no-install-recommends \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies first so Docker layer caching survives code-only changes.
COPY qdrant/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

COPY qdrant /app
# Shared helpers (common.image_io) used by the service code.
COPY common /app/common

ENV PYTHONUNBUFFERED=1

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
358
qdrant/main.py
Normal file
358
qdrant/main.py
Normal file
@@ -0,0 +1,358 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel, Field
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
Distance,
|
||||
PointStruct,
|
||||
VectorParams,
|
||||
Filter,
|
||||
FieldCondition,
|
||||
MatchValue,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Connection settings and defaults, overridable via environment
# (see docker-compose.yml).
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000")
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "images")
VECTOR_DIM = int(os.getenv("VECTOR_DIM", "512"))  # must match CLIP's embedding size

app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0")
# Populated in the startup hook; annotated non-Optional to keep call sites clean.
client: QdrantClient = None  # type: ignore[assignment]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Startup / shutdown
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup():
|
||||
global client
|
||||
client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
|
||||
_ensure_collection()
|
||||
|
||||
|
||||
def _ensure_collection():
|
||||
"""Create the default collection if it does not exist yet."""
|
||||
collections = [c.name for c in client.get_collections().collections]
|
||||
if COLLECTION_NAME not in collections:
|
||||
client.create_collection(
|
||||
collection_name=COLLECTION_NAME,
|
||||
vectors_config=VectorParams(size=VECTOR_DIM, distance=Distance.COSINE),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / Response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class UpsertUrlRequest(BaseModel):
    """Embed the image at `url` via CLIP, then store the vector."""
    url: str
    id: Optional[str] = None  # auto-generated (uuid4 hex) when omitted
    metadata: Dict[str, Any] = Field(default_factory=dict)
    collection: Optional[str] = None  # defaults to COLLECTION_NAME


class UpsertVectorRequest(BaseModel):
    """Store a pre-computed vector directly (no CLIP call)."""
    vector: List[float]
    id: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    collection: Optional[str] = None


class SearchUrlRequest(BaseModel):
    """Embed the image at `url` via CLIP, then search for similar vectors."""
    url: str
    limit: int = Field(default=5, ge=1, le=100)
    score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    collection: Optional[str] = None
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)  # exact-match payload filter


class SearchVectorRequest(BaseModel):
    """Search with a pre-computed query vector."""
    vector: List[float]
    limit: int = Field(default=5, ge=1, le=100)
    score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    collection: Optional[str] = None
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)


class DeleteRequest(BaseModel):
    """Delete the given point ids from a collection."""
    ids: List[str]
    collection: Optional[str] = None


class CollectionRequest(BaseModel):
    """Create a collection with the given dimensionality and metric."""
    name: str
    vector_dim: int = Field(default=512, ge=1)
    distance: str = Field(default="cosine")  # one of: cosine, euclid, dot
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _col(name: Optional[str]) -> str:
|
||||
return name or COLLECTION_NAME
|
||||
|
||||
|
||||
async def _embed_url(url: str) -> List[float]:
    """Call the CLIP service to get an image embedding for a URL."""
    async with httpx.AsyncClient(timeout=30) as http:
        r = await http.post(f"{CLIP_URL}/embed", json={"url": url})
        if r.status_code >= 400:
            # Surface CLIP failures as 502 (bad upstream), truncating the body.
            raise HTTPException(502, f"CLIP /embed error: {r.status_code} {r.text[:200]}")
        return r.json()["vector"]


async def _embed_bytes(data: bytes) -> List[float]:
    """Call the CLIP service to embed uploaded file bytes."""
    async with httpx.AsyncClient(timeout=30) as http:
        files = {"file": ("image", data, "application/octet-stream")}
        r = await http.post(f"{CLIP_URL}/embed/file", files=files)
        if r.status_code >= 400:
            raise HTTPException(502, f"CLIP /embed/file error: {r.status_code} {r.text[:200]}")
        return r.json()["vector"]
|
||||
|
||||
|
||||
def _build_filter(metadata: Dict[str, Any]) -> Optional[Filter]:
|
||||
if not metadata:
|
||||
return None
|
||||
conditions = [
|
||||
FieldCondition(key=k, match=MatchValue(value=v))
|
||||
for k, v in metadata.items()
|
||||
]
|
||||
return Filter(must=conditions)
|
||||
|
||||
|
||||
def _point_id(raw: Optional[str]) -> str:
|
||||
return raw or uuid.uuid4().hex
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
try:
|
||||
info = client.get_collections()
|
||||
names = [c.name for c in info.collections]
|
||||
return {"status": "ok", "qdrant": QDRANT_HOST, "collections": names}
|
||||
except Exception as e:
|
||||
return {"status": "error", "detail": str(e)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Collection management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.post("/collections")
|
||||
def create_collection(req: CollectionRequest):
|
||||
dist_map = {"cosine": Distance.COSINE, "euclid": Distance.EUCLID, "dot": Distance.DOT}
|
||||
dist = dist_map.get(req.distance.lower())
|
||||
if dist is None:
|
||||
raise HTTPException(400, f"Unknown distance: {req.distance}. Use cosine, euclid, or dot.")
|
||||
|
||||
collections = [c.name for c in client.get_collections().collections]
|
||||
if req.name in collections:
|
||||
raise HTTPException(409, f"Collection '{req.name}' already exists")
|
||||
|
||||
client.create_collection(
|
||||
collection_name=req.name,
|
||||
vectors_config=VectorParams(size=req.vector_dim, distance=dist),
|
||||
)
|
||||
return {"created": req.name, "vector_dim": req.vector_dim, "distance": req.distance}
|
||||
|
||||
|
||||
@app.get("/collections")
|
||||
def list_collections():
|
||||
info = client.get_collections()
|
||||
return {"collections": [c.name for c in info.collections]}
|
||||
|
||||
|
||||
@app.get("/collections/{name}")
|
||||
def collection_info(name: str):
|
||||
try:
|
||||
info = client.get_collection(name)
|
||||
return {
|
||||
"name": name,
|
||||
"vectors_count": info.vectors_count,
|
||||
"points_count": info.points_count,
|
||||
"status": info.status.value if info.status else None,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(404, str(e))
|
||||
|
||||
|
||||
@app.delete("/collections/{name}")
|
||||
def delete_collection(name: str):
|
||||
client.delete_collection(name)
|
||||
return {"deleted": name}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upsert endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.post("/upsert")
|
||||
async def upsert_url(req: UpsertUrlRequest):
|
||||
"""Embed an image by URL via CLIP, then store the vector in Qdrant."""
|
||||
vector = await _embed_url(req.url)
|
||||
pid = _point_id(req.id)
|
||||
payload = {**req.metadata, "source_url": req.url}
|
||||
col = _col(req.collection)
|
||||
|
||||
client.upsert(
|
||||
collection_name=col,
|
||||
points=[PointStruct(id=pid, vector=vector, payload=payload)],
|
||||
)
|
||||
return {"id": pid, "collection": col, "dim": len(vector)}
|
||||
|
||||
|
||||
@app.post("/upsert/file")
|
||||
async def upsert_file(
|
||||
file: UploadFile = File(...),
|
||||
id: Optional[str] = Form(None),
|
||||
collection: Optional[str] = Form(None),
|
||||
metadata_json: Optional[str] = Form(None),
|
||||
):
|
||||
"""Embed an uploaded image via CLIP, then store the vector in Qdrant."""
|
||||
import json
|
||||
|
||||
data = await file.read()
|
||||
vector = await _embed_bytes(data)
|
||||
pid = _point_id(id)
|
||||
|
||||
payload: Dict[str, Any] = {}
|
||||
if metadata_json:
|
||||
try:
|
||||
payload = json.loads(metadata_json)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(400, "metadata_json must be valid JSON")
|
||||
|
||||
col = _col(collection)
|
||||
client.upsert(
|
||||
collection_name=col,
|
||||
points=[PointStruct(id=pid, vector=vector, payload=payload)],
|
||||
)
|
||||
return {"id": pid, "collection": col, "dim": len(vector)}
|
||||
|
||||
|
||||
@app.post("/upsert/vector")
|
||||
def upsert_vector(req: UpsertVectorRequest):
|
||||
"""Store a pre-computed vector directly (skip CLIP embedding)."""
|
||||
pid = _point_id(req.id)
|
||||
col = _col(req.collection)
|
||||
client.upsert(
|
||||
collection_name=col,
|
||||
points=[PointStruct(id=pid, vector=req.vector, payload=req.metadata)],
|
||||
)
|
||||
return {"id": pid, "collection": col, "dim": len(req.vector)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Search endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.post("/search")
|
||||
async def search_url(req: SearchUrlRequest):
|
||||
"""Embed an image by URL via CLIP, then search Qdrant for similar vectors."""
|
||||
vector = await _embed_url(req.url)
|
||||
return _do_search(vector, req.limit, req.score_threshold, req.collection, req.filter_metadata)
|
||||
|
||||
|
||||
@app.post("/search/file")
|
||||
async def search_file(
|
||||
file: UploadFile = File(...),
|
||||
limit: int = Form(5),
|
||||
score_threshold: Optional[float] = Form(None),
|
||||
collection: Optional[str] = Form(None),
|
||||
):
|
||||
"""Embed an uploaded image via CLIP, then search Qdrant for similar vectors."""
|
||||
data = await file.read()
|
||||
vector = await _embed_bytes(data)
|
||||
return _do_search(vector, int(limit), score_threshold, collection, {})
|
||||
|
||||
|
||||
@app.post("/search/vector")
|
||||
def search_vector(req: SearchVectorRequest):
|
||||
"""Search Qdrant using a pre-computed vector."""
|
||||
return _do_search(req.vector, req.limit, req.score_threshold, req.collection, req.filter_metadata)
|
||||
|
||||
|
||||
def _do_search(
    vector: List[float],
    limit: int,
    score_threshold: Optional[float],
    collection: Optional[str],
    filter_metadata: Dict[str, Any],
):
    """Run a similarity query against Qdrant and shape the hits for the API."""
    target = _col(collection)

    response = client.query_points(
        collection_name=target,
        query=vector,
        limit=limit,
        score_threshold=score_threshold,
        query_filter=_build_filter(filter_metadata),
    )

    hits = [
        {"id": point.id, "score": point.score, "metadata": point.payload}
        for point in response.points
    ]

    return {"results": hits, "collection": target, "count": len(hits)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Delete points
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.post("/delete")
|
||||
def delete_points(req: DeleteRequest):
|
||||
col = _col(req.collection)
|
||||
client.delete(
|
||||
collection_name=col,
|
||||
points_selector=req.ids,
|
||||
)
|
||||
return {"deleted": req.ids, "collection": col}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Get point by ID
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/points/{point_id}")
|
||||
def get_point(point_id: str, collection: Optional[str] = None):
|
||||
col = _col(collection)
|
||||
try:
|
||||
points = client.retrieve(collection_name=col, ids=[point_id], with_vectors=True)
|
||||
if not points:
|
||||
raise HTTPException(404, f"Point '{point_id}' not found")
|
||||
p = points[0]
|
||||
return {
|
||||
"id": p.id,
|
||||
"vector": p.vector,
|
||||
"metadata": p.payload,
|
||||
"collection": col,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(404, str(e))
|
||||
8
qdrant/requirements.txt
Normal file
8
qdrant/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi==0.115.5
|
||||
uvicorn[standard]==0.30.6
|
||||
python-multipart==0.0.9
|
||||
requests==2.32.3
|
||||
pillow==10.4.0
|
||||
qdrant-client==1.12.1
|
||||
httpx==0.27.2
|
||||
numpy
|
||||
19
yolo/Dockerfile
Normal file
19
yolo/Dockerfile
Normal file
@@ -0,0 +1,19 @@
|
||||
FROM python:3.11-slim

WORKDIR /app

# libgl1/libglib2.0-0 are runtime dependencies of OpenCV, which ultralytics
# pulls in; ca-certificates for outbound HTTPS.
RUN apt-get update && apt-get install -y --no-install-recommends \
    ca-certificates \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies first so Docker layer caching survives code-only changes.
COPY yolo/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

COPY yolo /app
# Shared helpers (common.image_io) used by the service code.
COPY common /app/common

ENV PYTHONUNBUFFERED=1

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
70
yolo/main.py
Normal file
70
yolo/main.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel, Field
|
||||
from ultralytics import YOLO
|
||||
|
||||
from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError
|
||||
|
||||
YOLO_MODEL = os.getenv("YOLO_MODEL", "yolov8n.pt")  # weights file or hub name
# Reported in /health; device placement is handled by ultralytics itself.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

app = FastAPI(title="Skinbase YOLO Service", version="1.0.0")

# Loaded once at import time so the first request doesn't pay the load cost.
model = YOLO(YOLO_MODEL)
|
||||
|
||||
|
||||
class DetectRequest(BaseModel):
    """URL-based detection request."""
    url: Optional[str] = None  # validated as required in the endpoint handler
    conf: float = Field(default=0.25, ge=0.0, le=1.0)  # confidence threshold
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"status": "ok", "device": DEVICE, "model": YOLO_MODEL}
|
||||
|
||||
|
||||
def _detect_bytes(data: bytes, conf: float):
    """Run YOLO on raw image bytes and return the best score per class label.

    Bug fix: the confidence threshold is now forwarded to the model call.
    Previously `model(img)` ran with ultralytics' internal default threshold,
    so requesting a lower `conf` could never surface additional low-confidence
    boxes — the post-filter below only ever narrowed the results.

    Raises:
        ImageLoadError: if the bytes cannot be decoded as an image.
    """
    img = bytes_to_pil(data)

    results = model(img, conf=conf)

    # Keep only the highest-scoring box per class label.
    best: Dict[str, float] = {}
    for r in results:
        for box in r.boxes:
            score = float(box.conf[0])
            if score < conf:
                continue
            cls_id = int(box.cls[0])
            label = model.names.get(cls_id, str(cls_id))
            if label not in best or best[label] < score:
                best[label] = score

    detections = [{"label": k, "confidence": v} for k, v in best.items()]
    detections.sort(key=lambda x: x["confidence"], reverse=True)

    return {"detections": detections, "model": YOLO_MODEL}
|
||||
|
||||
|
||||
@app.post("/detect")
|
||||
def detect(req: DetectRequest):
|
||||
if not req.url:
|
||||
raise HTTPException(400, "url is required")
|
||||
try:
|
||||
data = fetch_url_bytes(req.url)
|
||||
return _detect_bytes(data, float(req.conf))
|
||||
except ImageLoadError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
|
||||
|
||||
@app.post("/detect/file")
|
||||
async def detect_file(
|
||||
file: UploadFile = File(...),
|
||||
conf: float = Form(0.25),
|
||||
):
|
||||
data = await file.read()
|
||||
return _detect_bytes(data, float(conf))
|
||||
8
yolo/requirements.txt
Normal file
8
yolo/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi==0.115.5
|
||||
uvicorn[standard]==0.30.6
|
||||
python-multipart==0.0.9
|
||||
requests==2.32.3
|
||||
pillow==10.4.0
|
||||
torch==2.4.1
|
||||
torchvision==0.19.1
|
||||
ultralytics==8.3.5
|
||||
Reference in New Issue
Block a user