147 lines
4.3 KiB
Python
147 lines
4.3 KiB
Python
"""Pytest port of BTE's FakeVlmEvaluationParseTests + VlmRubric parse coverage.

These tests cover the pure-function side of the VLM router (rubric prompt + score
parsing). The HTTP call to Spark 2 is exercised separately, through TestClient with
``httpx.AsyncClient`` replaced by a stub class.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from decimal import Decimal
|
|
from unittest.mock import patch
|
|
|
|
import httpx
|
|
from fastapi.testclient import TestClient
|
|
|
|
from svrnty_vision.routers.vlm import build_rubric_prompt, parse_scores
|
|
from svrnty_vision.server import app
|
|
|
|
# Module-level TestClient wrapping the imported ``app``; the /vlm/analyze tests
# in this module send their requests through it.
client = TestClient(app)
|
|
|
|
|
|
def test_build_rubric_prompt_includes_brand_and_mode() -> None:
    """The prompt built for "polished" mode names the brand, the mode and both score keys."""
    rendered = build_rubric_prompt("test-brand", "polished")

    # Same checks, same order: brand context, rubric mode, then the two score keys.
    for expected_fragment in ("test-brand", "polished", "brand_fit", "visual_polish"):
        assert expected_fragment in rendered
|
|
|
|
|
|
def test_build_rubric_prompt_raw_passes_through_verbatim() -> None:
    """In ``raw`` mode the builder returns its first argument unchanged."""
    produced = build_rubric_prompt("anything goes", "raw")

    assert produced == "anything goes"
|
|
|
|
|
|
def test_parse_scores_extracts_canned_payload() -> None:
    """A bare JSON reply yields both scores plus the mode and model id passed in."""
    payload = '{"brand_fit":4.20,"visual_polish":3.80,"justification":"clean lighting"}'

    scores = parse_scores(payload, "polished", "qwen-test")

    # Both numeric fields must compare equal to the Decimal form of the JSON numbers.
    assert scores.brand_fit_score == Decimal("4.20")
    assert scores.visual_polish_score == Decimal("3.80")
    # The mode and model id given to parse_scores are expected back on the result.
    assert (scores.rubric_mode, scores.model_id) == ("polished", "qwen-test")
    assert "brand_fit" in scores.raw_scores_json
|
|
|
|
|
|
def test_parse_scores_tolerates_leading_or_trailing_prose() -> None:
    """Text before and after the JSON object does not prevent the scores being read."""
    noisy_reply = 'sure: {"brand_fit":2,"visual_polish":3}. done.'

    parsed = parse_scores(noisy_reply, "ugc", "m")
    brand_fit, visual_polish = parsed.brand_fit_score, parsed.visual_polish_score

    assert brand_fit == Decimal("2")
    assert visual_polish == Decimal("3")
|
|
|
|
|
|
def test_parse_scores_raw_mode_returns_text_untouched() -> None:
    """In ``raw`` mode no scores are set and the input text is kept as-is."""
    model_text = '{"subject":"plate","mood":"warm"}'

    outcome = parse_scores(model_text, "raw", "m")

    # Neither rubric score is populated in raw mode.
    for score in (outcome.brand_fit_score, outcome.visual_polish_score):
        assert score is None
    assert outcome.raw_scores_json == model_text
|
|
|
|
|
|
def test_analyze_requires_image_input() -> None:
    """A request body that carries no image field is answered with HTTP 400."""
    body_without_image = {"brand_context": "x", "rubric_mode": "polished"}

    reply = client.post("/vlm/analyze", json=body_without_image)

    assert reply.status_code == 400
|
|
|
|
|
|
def test_analyze_returns_502_when_spark2_unreachable() -> None:
    """Smoke test: a connection failure on the upstream POST is reported as HTTP 502.

    ``httpx.AsyncClient`` inside the VLM router module is swapped for a stub whose
    ``post`` always raises ``httpx.ConnectError``.
    """

    class _UnreachableClient:
        """Async-context-manager stand-in whose POST never connects."""

        def __init__(self, *args, **kwargs):
            """Accept and ignore any constructor arguments."""

        async def __aenter__(self):
            return self

        async def __aexit__(self, *exc_info):
            # Returning False lets exceptions raised inside the block propagate.
            return False

        async def post(self, *args, **kwargs):
            raise httpx.ConnectError("boom")

    request_body = {
        "image_base64": "ZmFrZQ==",
        "brand_context": "test",
        "rubric_mode": "polished",
    }

    with patch("svrnty_vision.routers.vlm.httpx.AsyncClient", new=_UnreachableClient):
        reply = client.post("/vlm/analyze", json=request_body)

    assert reply.status_code == 502
|
|
|
|
|
|
def test_analyze_round_trip_with_mocked_spark2() -> None:
    """Happy path: the stubbed vLLM reply holds well-formed score JSON and is parsed.

    ``httpx.AsyncClient`` inside the VLM router module is swapped for a stub whose
    ``post`` returns a fixed 200 response; the endpoint's JSON body is then checked
    for both scores and the rubric mode.
    """
    model_reply = '{"brand_fit": 4.5, "visual_polish": 4.0, "justification": "ok"}'
    completion = {"choices": [{"message": {"content": model_reply}}]}

    class _CannedResponse:
        """Minimal response object: status 200, no error, fixed JSON body."""

        status_code = 200

        def raise_for_status(self):
            return None

        def json(self):
            return completion

    class _CannedClient:
        """Async-context-manager stand-in whose POST returns the canned response."""

        def __init__(self, *args, **kwargs):
            """Accept and ignore any constructor arguments."""

        async def __aenter__(self):
            return self

        async def __aexit__(self, *exc_info):
            # Returning False lets exceptions raised inside the block propagate.
            return False

        async def post(self, *args, **kwargs):
            return _CannedResponse()

    request_body = {
        "image_base64": "ZmFrZQ==",
        "brand_context": "test-brand",
        "rubric_mode": "polished",
    }

    with patch("svrnty_vision.routers.vlm.httpx.AsyncClient", new=_CannedClient):
        reply = client.post("/vlm/analyze", json=request_body)

    # Include the response text in the failure message to ease debugging.
    assert reply.status_code == 200, reply.text
    scored = reply.json()
    assert Decimal(scored["brand_fit_score"]) == Decimal("4.5")
    assert Decimal(scored["visual_polish_score"]) == Decimal("4.0")
    assert scored["rubric_mode"] == "polished"
|