"""Unit tests for the gateway's LLM proxy endpoints: auth, feature gating, error normalization, and response shaping."""
from __future__ import annotations
|
|
|
|
import importlib
|
|
import os
|
|
import unittest
|
|
from typing import Any, Dict, Optional
|
|
from unittest.mock import patch
|
|
|
|
import httpx
|
|
|
|
|
|
# Minimal environment the gateway needs at import time; individual tests layer
# LLM_ENABLED (and any per-test overrides) on top of this mapping.
BASE_ENV = dict(
    API_KEY="test-key",
    CLIP_URL="http://clip:8000",
    BLIP_URL="http://blip:8000",
    YOLO_URL="http://yolo:8000",
    QDRANT_SVC_URL="http://qdrant-svc:8000",
    CARD_RENDERER_URL="http://card-renderer:8000",
    MATURITY_URL="http://maturity:8000",
    LLM_URL="http://llm:8080",
    LLM_TIMEOUT="5",
    LLM_DEFAULT_MODEL="qwen3-1.7b-instruct-q4_k_m",
    LLM_MAX_TOKENS_DEFAULT="256",
    LLM_MAX_TOKENS_HARD_LIMIT="1024",
    LLM_MAX_REQUEST_BYTES="65536",
)
|
|
|
|
|
|
def load_gateway_module(*, llm_enabled: bool, extra_env: Optional[Dict[str, str]] = None):
    """Reload ``gateway.main`` under a patched environment and return the module.

    The gateway reads its configuration from the environment at import time, so
    each test rebuilds BASE_ENV plus ``LLM_ENABLED`` (and any *extra_env*
    overrides) and reloads the module while that environment is in effect.
    """
    env = dict(BASE_ENV)
    env["LLM_ENABLED"] = "true" if llm_enabled else "false"
    if extra_env:
        env.update(extra_env)
    with patch.dict(os.environ, env, clear=False):
        import gateway.main as gateway_main

        # Reload inside the patch so module-level config picks up `env`.
        return importlib.reload(gateway_main)
|
|
|
|
|
|
class StubUpstreamClient:
    """In-memory stand-in for the gateway's shared upstream HTTP client.

    Canned responses are looked up per ``(METHOD, url)`` for :meth:`request`
    and per ``url`` for :meth:`get`; when an exception is configured it is
    raised instead, simulating transport failures. Unknown routes produce a
    404 so a missing stub surfaces as a visible test failure.
    """

    def __init__(
        self,
        *,
        request_responses: Optional[Dict[tuple[str, str], httpx.Response]] = None,
        get_responses: Optional[Dict[str, httpx.Response]] = None,
        request_exception: Optional[Exception] = None,
        get_exception: Optional[Exception] = None,
    ):
        # Fall back to fresh empty maps so lookups never hit None.
        self.request_responses = request_responses or {}
        self.get_responses = get_responses or {}
        self.request_exception = request_exception
        self.get_exception = get_exception

    async def request(self, method: str, url: str, **_: Any) -> httpx.Response:
        # Configured failure takes precedence over any canned response.
        if self.request_exception is not None:
            raise self.request_exception
        key = (method.upper(), url)
        if key in self.request_responses:
            return self.request_responses[key]
        return httpx.Response(404, json={"error": {"message": f"No stub for {method} {url}"}})

    async def get(self, url: str, **_: Any) -> httpx.Response:
        if self.get_exception is not None:
            raise self.get_exception
        if url in self.get_responses:
            return self.get_responses[url]
        return httpx.Response(404, json={"detail": f"No stub for GET {url}"})
|
|
|
|
|
|
class GatewayLLMTests(unittest.IsolatedAsyncioTestCase):
    """End-to-end tests for the gateway's LLM proxy endpoints.

    Each test reloads ``gateway.main`` with a tailored environment, drives the
    ASGI app in-process via ``httpx.ASGITransport`` (no sockets), and patches
    ``get_http_client`` with a :class:`StubUpstreamClient` so no real upstream
    calls are made.
    """

    async def _request(
        self,
        module: Any,
        method: str,
        path: str,
        *,
        headers: Optional[Dict[str, str]] = None,
        json_payload: Optional[Dict[str, Any]] = None,
        content: Optional[bytes] = None,
    ) -> httpx.Response:
        # Helper: issue one in-process request against the freshly loaded app.
        # A new client per call keeps tests isolated from connection state.
        transport = httpx.ASGITransport(app=module.app)
        async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
            return await client.request(method, path, headers=headers, json=json_payload, content=content)

    async def test_llm_endpoint_requires_api_key(self):
        # A request without the X-API-Key header must be rejected with a
        # normalized 401 error envelope.
        module = load_gateway_module(llm_enabled=True)

        response = await self._request(
            module,
            "POST",
            "/ai/chat",
            json_payload={"messages": [{"role": "user", "content": "hello"}]},
        )

        self.assertEqual(response.status_code, 401)
        self.assertEqual(response.json()["error"]["code"], "unauthorized")

    async def test_llm_disabled_returns_503(self):
        # With LLM_ENABLED=false the endpoint is feature-gated off entirely.
        module = load_gateway_module(llm_enabled=False)

        response = await self._request(
            module,
            "POST",
            "/ai/chat",
            headers={"X-API-Key": "test-key"},
            json_payload={"messages": [{"role": "user", "content": "hello"}]},
        )

        self.assertEqual(response.status_code, 503)
        self.assertEqual(response.json()["error"]["code"], "llm_disabled")

    async def test_unreachable_llm_returns_normalized_503(self):
        # A transport-level ConnectError from the upstream is normalized to a
        # 503 with code "llm_unavailable" rather than leaking the exception.
        module = load_gateway_module(llm_enabled=True)
        stub_client = StubUpstreamClient(
            request_exception=httpx.ConnectError("boom", request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions")),
        )

        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "POST",
                "/ai/chat",
                headers={"X-API-Key": "test-key"},
                json_payload={"messages": [{"role": "user", "content": "hello"}]},
            )

        self.assertEqual(response.status_code, 503)
        self.assertEqual(response.json()["error"]["code"], "llm_unavailable")

    async def test_validation_error_is_normalized(self):
        # An empty messages list fails schema validation; the framework's 422
        # is wrapped in the gateway's error envelope.
        module = load_gateway_module(llm_enabled=True)

        response = await self._request(
            module,
            "POST",
            "/ai/chat",
            headers={"X-API-Key": "test-key"},
            json_payload={"messages": []},
        )

        self.assertEqual(response.status_code, 422)
        self.assertEqual(response.json()["error"]["code"], "validation_error")

    async def test_invalid_json_returns_400(self):
        # Truncated JSON bodies must produce a 400 "invalid_json", not a 500.
        module = load_gateway_module(llm_enabled=True)

        response = await self._request(
            module,
            "POST",
            "/v1/chat/completions",
            headers={"X-API-Key": "test-key", "Content-Type": "application/json"},
            content=b'{"messages": [',
        )

        self.assertEqual(response.status_code, 400)
        self.assertEqual(response.json()["error"]["code"], "invalid_json")

    async def test_oversized_payload_returns_413(self):
        # Shrink the request-size limit to 64 bytes so an ordinary payload
        # trips the 413 guard.
        module = load_gateway_module(llm_enabled=True, extra_env={"LLM_MAX_REQUEST_BYTES": "64"})

        response = await self._request(
            module,
            "POST",
            "/v1/chat/completions",
            headers={"X-API-Key": "test-key"},
            json_payload={"messages": [{"role": "user", "content": "x" * 5000}]},
        )

        self.assertEqual(response.status_code, 413)
        self.assertEqual(response.json()["error"]["code"], "payload_too_large")

    async def test_ai_chat_normalizes_successful_response(self):
        # /ai/chat flattens the OpenAI-style completion payload down to
        # {model, content, finish_reason, usage}.
        module = load_gateway_module(llm_enabled=True)
        upstream_response = httpx.Response(
            200,
            json={
                "id": "chatcmpl-1",
                "object": "chat.completion",
                "model": "qwen3-1.7b-instruct-q4_k_m",
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": "stop",
                        "message": {"role": "assistant", "content": "Generated text here."},
                    }
                ],
                "usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20},
            },
            request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions"),
        )
        stub_client = StubUpstreamClient(
            request_responses={("POST", f"{module.LLM_URL}/v1/chat/completions"): upstream_response},
        )

        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "POST",
                "/ai/chat",
                headers={"X-API-Key": "test-key"},
                json_payload={"messages": [{"role": "user", "content": "hello"}]},
            )

        self.assertEqual(response.status_code, 200)
        self.assertEqual(
            response.json(),
            {
                "model": "qwen3-1.7b-instruct-q4_k_m",
                "content": "Generated text here.",
                "finish_reason": "stop",
                "usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20},
            },
        )

    async def test_ai_health_reports_reachable_llm(self):
        # /ai/health probes the upstream /health endpoint and reports
        # reachability plus the configured default model.
        module = load_gateway_module(llm_enabled=True)
        stub_client = StubUpstreamClient(
            get_responses={
                f"{module.LLM_URL}/health": httpx.Response(
                    200,
                    json={"status": "ok", "model": "Qwen3-1.7B-Instruct-Q4_K_M.gguf", "context_size": 4096, "threads": 4},
                    request=httpx.Request("GET", f"{module.LLM_URL}/health"),
                )
            },
        )

        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "GET",
                "/ai/health",
                headers={"X-API-Key": "test-key"},
            )

        self.assertEqual(response.status_code, 200)
        self.assertTrue(response.json()["reachable"])
        self.assertEqual(response.json()["default_model"], "qwen3-1.7b-instruct-q4_k_m")

    async def test_timeout_returns_504(self):
        # An upstream ReadTimeout maps to 504 "llm_timeout" (distinct from the
        # 503 used for connection failures).
        module = load_gateway_module(llm_enabled=True)
        stub_client = StubUpstreamClient(
            request_exception=httpx.ReadTimeout("timeout", request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions")),
        )

        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "POST",
                "/ai/chat",
                headers={"X-API-Key": "test-key"},
                json_payload={"messages": [{"role": "user", "content": "hello"}]},
            )

        self.assertEqual(response.status_code, 504)
        self.assertEqual(response.json()["error"]["code"], "llm_timeout")

    async def test_upstream_400_is_preserved(self):
        # A 4xx from the upstream keeps its status code but is re-wrapped in
        # the gateway's error envelope with code "llm_rejected_request".
        module = load_gateway_module(llm_enabled=True)
        bad_request_response = httpx.Response(
            400,
            json={"error": {"message": "Bad prompt"}},
            request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions"),
        )
        stub_client = StubUpstreamClient(
            request_responses={("POST", f"{module.LLM_URL}/v1/chat/completions"): bad_request_response},
        )

        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "POST",
                "/v1/chat/completions",
                headers={"X-API-Key": "test-key"},
                json_payload={"messages": [{"role": "user", "content": "hello"}]},
            )

        self.assertEqual(response.status_code, 400)
        self.assertEqual(response.json()["error"]["code"], "llm_rejected_request")

    async def test_models_endpoint_returns_upstream_metadata(self):
        # /v1/models is proxied through; upstream model metadata must come
        # back intact.
        module = load_gateway_module(llm_enabled=True)
        models_response = httpx.Response(
            200,
            json={
                "object": "list",
                "data": [
                    {
                        "id": "qwen3-1.7b-instruct-q4_k_m",
                        "object": "model",
                        "owned_by": "self-hosted",
                    }
                ],
            },
            request=httpx.Request("GET", f"{module.LLM_URL}/v1/models"),
        )
        stub_client = StubUpstreamClient(
            request_responses={("GET", f"{module.LLM_URL}/v1/models"): models_response},
        )

        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "GET",
                "/v1/models",
                headers={"X-API-Key": "test-key"},
            )

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.json()["data"][0]["id"], "qwen3-1.7b-instruct-q4_k_m")
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly with `python <file>`.
    unittest.main()