# vision/tests/test_gateway_llm.py
# (313 lines, 11 KiB, Python)
from __future__ import annotations
import importlib
import os
import unittest
from typing import Any, Dict, Optional
from unittest.mock import patch
import httpx
BASE_ENV = {
"API_KEY": "test-key",
"CLIP_URL": "http://clip:8000",
"BLIP_URL": "http://blip:8000",
"YOLO_URL": "http://yolo:8000",
"QDRANT_SVC_URL": "http://qdrant-svc:8000",
"CARD_RENDERER_URL": "http://card-renderer:8000",
"MATURITY_URL": "http://maturity:8000",
"LLM_URL": "http://llm:8080",
"LLM_TIMEOUT": "5",
"LLM_DEFAULT_MODEL": "qwen3-1.7b-instruct-q4_k_m",
"LLM_MAX_TOKENS_DEFAULT": "256",
"LLM_MAX_TOKENS_HARD_LIMIT": "1024",
"LLM_MAX_REQUEST_BYTES": "65536",
}
def load_gateway_module(*, llm_enabled: bool, extra_env: Optional[Dict[str, str]] = None):
    """Reload ``gateway.main`` under a patched environment and return the module.

    Builds the test environment from ``BASE_ENV``, toggles ``LLM_ENABLED``,
    applies any per-test ``extra_env`` overrides, and reloads the module so
    that its module-level configuration is re-read from ``os.environ``.
    """
    env = dict(BASE_ENV)
    env["LLM_ENABLED"] = "true" if llm_enabled else "false"
    if extra_env:
        env.update(extra_env)
    # clear=False: only overlay our keys, keep the rest of the real environment.
    with patch.dict(os.environ, env, clear=False):
        import gateway.main as gateway_main

        return importlib.reload(gateway_main)
class StubUpstreamClient:
    """Async stand-in for the gateway's shared upstream HTTP client.

    Serves canned :class:`httpx.Response` objects keyed by ``(METHOD, url)``
    for ``request`` and by ``url`` for ``get``.  When a configured exception
    is present it takes precedence and is raised instead of returning a
    response; unknown routes yield a 404 stub payload.
    """

    def __init__(
        self,
        *,
        request_responses: Optional[Dict[tuple[str, str], httpx.Response]] = None,
        get_responses: Optional[Dict[str, httpx.Response]] = None,
        request_exception: Optional[Exception] = None,
        get_exception: Optional[Exception] = None,
    ):
        # `or {}` keeps the attributes plain dicts even when None is passed.
        self.request_responses = request_responses or {}
        self.get_responses = get_responses or {}
        self.request_exception = request_exception
        self.get_exception = get_exception

    async def request(self, method: str, url: str, **_: Any) -> httpx.Response:
        if self.request_exception is not None:
            raise self.request_exception
        try:
            # Lookup is case-normalized on the method, mirroring real clients.
            return self.request_responses[(method.upper(), url)]
        except KeyError:
            return httpx.Response(404, json={"error": {"message": f"No stub for {method} {url}"}})

    async def get(self, url: str, **_: Any) -> httpx.Response:
        if self.get_exception is not None:
            raise self.get_exception
        try:
            return self.get_responses[url]
        except KeyError:
            return httpx.Response(404, json={"detail": f"No stub for GET {url}"})
class GatewayLLMTests(unittest.IsolatedAsyncioTestCase):
    """In-process tests for the gateway's LLM proxy endpoints.

    Each test reloads the gateway module with a controlled environment, then
    drives the FastAPI app through ``httpx.ASGITransport`` (no real sockets).
    Upstream LLM traffic is intercepted by patching ``get_http_client`` with
    a :class:`StubUpstreamClient`.
    """

    async def _request(
        self,
        module: Any,
        method: str,
        path: str,
        *,
        headers: Optional[Dict[str, str]] = None,
        json_payload: Optional[Dict[str, Any]] = None,
        content: Optional[bytes] = None,
    ) -> httpx.Response:
        """Issue a single request against the module's ASGI app."""
        transport = httpx.ASGITransport(app=module.app)
        async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
            return await client.request(method, path, headers=headers, json=json_payload, content=content)

    async def test_llm_endpoint_requires_api_key(self):
        """Requests without X-API-Key are rejected with a normalized 401."""
        module = load_gateway_module(llm_enabled=True)
        response = await self._request(
            module,
            "POST",
            "/ai/chat",
            json_payload={"messages": [{"role": "user", "content": "hello"}]},
        )
        self.assertEqual(response.status_code, 401)
        self.assertEqual(response.json()["error"]["code"], "unauthorized")

    async def test_llm_disabled_returns_503(self):
        """With LLM_ENABLED=false the endpoint reports llm_disabled."""
        module = load_gateway_module(llm_enabled=False)
        response = await self._request(
            module,
            "POST",
            "/ai/chat",
            headers={"X-API-Key": "test-key"},
            json_payload={"messages": [{"role": "user", "content": "hello"}]},
        )
        self.assertEqual(response.status_code, 503)
        self.assertEqual(response.json()["error"]["code"], "llm_disabled")

    async def test_unreachable_llm_returns_normalized_503(self):
        """A connect error to the upstream is normalized to llm_unavailable."""
        module = load_gateway_module(llm_enabled=True)
        upstream_url = f"{module.LLM_URL}/v1/chat/completions"
        stub_client = StubUpstreamClient(
            request_exception=httpx.ConnectError("boom", request=httpx.Request("POST", upstream_url)),
        )
        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "POST",
                "/ai/chat",
                headers={"X-API-Key": "test-key"},
                json_payload={"messages": [{"role": "user", "content": "hello"}]},
            )
        self.assertEqual(response.status_code, 503)
        self.assertEqual(response.json()["error"]["code"], "llm_unavailable")

    async def test_validation_error_is_normalized(self):
        """An empty messages list fails validation with a normalized 422."""
        module = load_gateway_module(llm_enabled=True)
        response = await self._request(
            module,
            "POST",
            "/ai/chat",
            headers={"X-API-Key": "test-key"},
            json_payload={"messages": []},
        )
        self.assertEqual(response.status_code, 422)
        self.assertEqual(response.json()["error"]["code"], "validation_error")

    async def test_invalid_json_returns_400(self):
        """Malformed JSON bodies yield a 400 with code invalid_json."""
        module = load_gateway_module(llm_enabled=True)
        response = await self._request(
            module,
            "POST",
            "/v1/chat/completions",
            headers={"X-API-Key": "test-key", "Content-Type": "application/json"},
            content=b'{"messages": [',
        )
        self.assertEqual(response.status_code, 400)
        self.assertEqual(response.json()["error"]["code"], "invalid_json")

    async def test_oversized_payload_returns_413(self):
        """Bodies beyond LLM_MAX_REQUEST_BYTES are rejected with 413."""
        module = load_gateway_module(llm_enabled=True, extra_env={"LLM_MAX_REQUEST_BYTES": "64"})
        response = await self._request(
            module,
            "POST",
            "/v1/chat/completions",
            headers={"X-API-Key": "test-key"},
            json_payload={"messages": [{"role": "user", "content": "x" * 5000}]},
        )
        self.assertEqual(response.status_code, 413)
        self.assertEqual(response.json()["error"]["code"], "payload_too_large")

    async def test_ai_chat_normalizes_successful_response(self):
        """/ai/chat flattens the OpenAI-style upstream reply into the gateway shape."""
        module = load_gateway_module(llm_enabled=True)
        upstream_url = f"{module.LLM_URL}/v1/chat/completions"
        upstream_response = httpx.Response(
            200,
            json={
                "id": "chatcmpl-1",
                "object": "chat.completion",
                "model": "qwen3-1.7b-instruct-q4_k_m",
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": "stop",
                        "message": {"role": "assistant", "content": "Generated text here."},
                    }
                ],
                "usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20},
            },
            request=httpx.Request("POST", upstream_url),
        )
        stub_client = StubUpstreamClient(
            request_responses={("POST", upstream_url): upstream_response},
        )
        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "POST",
                "/ai/chat",
                headers={"X-API-Key": "test-key"},
                json_payload={"messages": [{"role": "user", "content": "hello"}]},
            )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(
            response.json(),
            {
                "model": "qwen3-1.7b-instruct-q4_k_m",
                "content": "Generated text here.",
                "finish_reason": "stop",
                "usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20},
            },
        )

    async def test_ai_health_reports_reachable_llm(self):
        """/ai/health reports reachable=true and the configured default model."""
        module = load_gateway_module(llm_enabled=True)
        health_url = f"{module.LLM_URL}/health"
        stub_client = StubUpstreamClient(
            get_responses={
                health_url: httpx.Response(
                    200,
                    json={"status": "ok", "model": "Qwen3-1.7B-Instruct-Q4_K_M.gguf", "context_size": 4096, "threads": 4},
                    request=httpx.Request("GET", health_url),
                )
            },
        )
        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "GET",
                "/ai/health",
                headers={"X-API-Key": "test-key"},
            )
        self.assertEqual(response.status_code, 200)
        self.assertTrue(response.json()["reachable"])
        self.assertEqual(response.json()["default_model"], "qwen3-1.7b-instruct-q4_k_m")

    async def test_timeout_returns_504(self):
        """An upstream read timeout maps to 504 with code llm_timeout."""
        module = load_gateway_module(llm_enabled=True)
        upstream_url = f"{module.LLM_URL}/v1/chat/completions"
        stub_client = StubUpstreamClient(
            request_exception=httpx.ReadTimeout("timeout", request=httpx.Request("POST", upstream_url)),
        )
        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "POST",
                "/ai/chat",
                headers={"X-API-Key": "test-key"},
                json_payload={"messages": [{"role": "user", "content": "hello"}]},
            )
        self.assertEqual(response.status_code, 504)
        self.assertEqual(response.json()["error"]["code"], "llm_timeout")

    async def test_upstream_400_is_preserved(self):
        """A 400 from the upstream keeps its status and is coded llm_rejected_request."""
        module = load_gateway_module(llm_enabled=True)
        upstream_url = f"{module.LLM_URL}/v1/chat/completions"
        bad_request_response = httpx.Response(
            400,
            json={"error": {"message": "Bad prompt"}},
            request=httpx.Request("POST", upstream_url),
        )
        stub_client = StubUpstreamClient(
            request_responses={("POST", upstream_url): bad_request_response},
        )
        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "POST",
                "/v1/chat/completions",
                headers={"X-API-Key": "test-key"},
                json_payload={"messages": [{"role": "user", "content": "hello"}]},
            )
        self.assertEqual(response.status_code, 400)
        self.assertEqual(response.json()["error"]["code"], "llm_rejected_request")

    async def test_models_endpoint_returns_upstream_metadata(self):
        """/v1/models passes through the upstream model listing."""
        module = load_gateway_module(llm_enabled=True)
        models_url = f"{module.LLM_URL}/v1/models"
        models_response = httpx.Response(
            200,
            json={
                "object": "list",
                "data": [
                    {
                        "id": "qwen3-1.7b-instruct-q4_k_m",
                        "object": "model",
                        "owned_by": "self-hosted",
                    }
                ],
            },
            request=httpx.Request("GET", models_url),
        )
        stub_client = StubUpstreamClient(
            request_responses={("GET", models_url): models_response},
        )
        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "GET",
                "/v1/models",
                headers={"X-API-Key": "test-key"},
            )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.json()["data"][0]["id"], "qwen3-1.7b-instruct-q4_k_m")
if __name__ == "__main__":
    # Allow running this test module directly: python test_gateway_llm.py
    unittest.main()