feat: port VLM + FLUX from BTE (Phase 4b)
VLM router (POST /vlm/analyze):
- Proxies to Spark 2 (Qwen3-VL via vLLM, OpenAI-compatible /v1/chat/completions)
- Port of BTE Svrnty.Bte.Domain/Features/AssetContext/OpenAiVlmClient.cs
+ VlmRubric.cs (rubric prompt builder + score parser)
- Anthropic dialect intentionally dropped — sovereign-only
- New rubric_mode="raw" passes brand_context through verbatim so BTE
ExtractBrandSaga / ImageSetSourceReader (extraction-style prompts that
expect their own JSON schema) get unwrapped JSON back without losing
the score-axis path
FLUX router (POST /flux/render):
- Proxies to Spark 1 (FLUX.2-dev on ComfyUI; /prompt + /history poll + /view)
- Port of SparkBComfyClient.cs + LocalFluxImageProvider.cs + StopgapFluxWorkflow.cs
- Accepts a pre-assembled workflow_json (BTE IRecipeAssembler emits one)
or builds the stopgap FLUX.2 graph from prompt + dims
Tests (pytest):
- test_vlm_parse.py — rubric prompt + score parse, 502 on Spark-down, mocked round-trip
- test_flux_workflow.py — stopgap graph shape, seed variance/determinism, 502 on Spark-down
- test_healthz.py updated (palette/rembg still 4a stubs)
16 pytest tests green.
Smoke (no Spark reachable):
- GET /healthz → 200 {"status":"ok"}
- POST /vlm/analyze → 502 "Spark 2 unreachable" (clear error)
- POST /flux/render → 502 "Spark 1 unreachable" (clear error)
Per BTE refactor audit §3 V — vision capabilities extracted from BTE to the
sovereign vision gateway. Phase 4c (delete-from-BTE) + Phase 4d (HTTP adapter)
follow in BTE.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
2a90c3f884
commit
e810c72ffa
@ -1,17 +1,247 @@
|
||||
"""FLUX image generation — stub until Phase 4b wires the ComfyUI HTTP client."""
|
||||
"""FLUX image generation — proxies to Spark 1 (ComfyUI HTTP API).
|
||||
|
||||
Ported from BTE's SparkBComfyClient.cs + LocalFluxImageProvider.cs +
|
||||
StopgapFluxWorkflow.cs (Phase 4b). Workflow is either supplied by the caller
|
||||
(BTE's IRecipeAssembler emits a full graph) or assembled inline from the
|
||||
stopgap FLUX.2 template.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from svrnty_vision.settings import settings
|
||||
|
||||
router = APIRouter(prefix="/flux", tags=["flux"])
|
||||
|
||||
|
||||
# Sampler defaults for the stopgap graph; values match BTE's StopgapFluxWorkflow.cs.
DEFAULT_GUIDANCE = 2.5
DEFAULT_STEPS = 20
# Seconds to sleep between successive ComfyUI /history polls.
POLL_INTERVAL_SECONDS = 2.0
|
||||
|
||||
class RenderRequest(BaseModel):
    """Either `workflow_json` (a pre-assembled ComfyUI graph) or `prompt` + dims
    (we build the stopgap FLUX.2-dev graph). `workflow_json` wins when both set.
    """

    # Text prompt for the stopgap graph; only consulted when no workflow_json is given.
    prompt: str | None = Field(default=None)
    # Latent dimensions for the stopgap graph.
    width: int = Field(default=1024)
    height: int = Field(default=1024)
    # Serialized ComfyUI graph; takes precedence over prompt/width/height.
    workflow_json: str | None = Field(default=None)
    # Sampler knobs for the stopgap graph.
    guidance: float = Field(default=DEFAULT_GUIDANCE)
    steps: int = Field(default=DEFAULT_STEPS)
    # null → random per call (ComfyUI dedup avoidance)
    seed: int | None = Field(default=None)
|
||||
|
||||
|
||||
class RenderResponse(BaseModel):
    # ComfyUI prompt id the image was rendered under.
    prompt_id: str
    # Base64 (ASCII) encoding of the downloaded image bytes.
    image_base64: str
    content_type: str = Field(default="image/png")
    # Wall-clock time spent in the render handler, in milliseconds.
    duration_ms: int
    provider: str = Field(default="local")
    model: str = Field(default="flux2")
|
||||
|
||||
|
||||
def build_stopgap_workflow(
    prompt: str,
    width: int,
    height: int,
    guidance: float = DEFAULT_GUIDANCE,
    steps: int = DEFAULT_STEPS,
    seed: int | None = None,
) -> str:
    """Assemble the stopgap FLUX.2 ComfyUI graph and return it as a JSON string.

    Port of StopgapFluxWorkflow.Build (Svrnty.Bte.Shared.Recipes). The graph
    loads the flux2_dev_fp8mixed UNET, the Mistral-3 CLIP (type "flux2") and
    flux2-vae; the KSampler runs at cfg=1.0 and the distilled guidance value is
    applied through a FluxGuidance node.

    When `seed` is None a random one is drawn for this call, because ComfyUI
    dedupes identical workflows.
    """

    def node(class_type: str, **inputs: Any) -> dict[str, Any]:
        # One ComfyUI graph node: class name plus its named inputs.
        return {"class_type": class_type, "inputs": inputs}

    resolved_seed = random.randint(1, 2**31 - 1) if seed is None else seed

    # Node ids and input order mirror the original graph so the serialized
    # JSON is identical for a given seed.
    nodes: dict[str, Any] = {
        "5": node("EmptySD3LatentImage", width=width, height=height, batch_size=1),
        "6": node("CLIPTextEncode", clip=["11", 0], text=prompt),
        "7": node("CLIPTextEncode", clip=["11", 0], text=""),
        "8": node("VAEDecode", samples=["13", 0], vae=["10", 0]),
        "9": node("SaveImage", filename_prefix="bte", images=["8", 0]),
        "10": node("VAELoader", vae_name="flux2-vae.safetensors"),
        "11": node(
            "CLIPLoader",
            clip_name="mistral_3_small_flux2_fp8.safetensors",
            type="flux2",
        ),
        "12": node(
            "UNETLoader",
            unet_name="flux2_dev_fp8mixed.safetensors",
            weight_dtype="default",
        ),
        "13": node(
            "KSampler",
            model=["12", 0],
            positive=["26", 0],
            negative=["7", 0],
            latent_image=["5", 0],
            seed=resolved_seed,
            steps=steps,
            cfg=1.0,
            sampler_name="euler",
            scheduler="simple",
            denoise=1.0,
        ),
        "26": node("FluxGuidance", conditioning=["6", 0], guidance=guidance),
    }
    return json.dumps(nodes)
|
||||
|
||||
|
||||
async def _queue_prompt(
    client: httpx.AsyncClient, endpoint: str, workflow_json: str
) -> str:
    """Submit a workflow to ComfyUI's /prompt endpoint and return its prompt id.

    Args:
        client: open HTTP client used for the request.
        endpoint: ComfyUI base URL (no trailing slash).
        workflow_json: serialized ComfyUI graph.

    Raises:
        HTTPException: 400 when `workflow_json` is not valid JSON; 502 when the
            ComfyUI reply carries no `prompt_id`.
        httpx.HTTPError: on transport failure or a non-2xx reply (propagated
            from the client / `raise_for_status`).
    """
    # The graph may come straight from the request body, so malformed JSON is a
    # client error rather than an unhandled server crash.
    try:
        workflow = json.loads(workflow_json)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"workflow_json is not valid JSON: {e}",
        ) from e

    resp = await client.post(f"{endpoint}/prompt", json={"prompt": workflow})
    resp.raise_for_status()
    body = resp.json()
    if not isinstance(body, dict) or "prompt_id" not in body:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"ComfyUI /prompt returned no prompt_id: {body}",
        )
    return body["prompt_id"]
|
||||
|
||||
|
||||
async def _poll_history(
    client: httpx.AsyncClient, endpoint: str, prompt_id: str
) -> list[dict[str, str]]:
    """Poll ComfyUI's /history/{prompt_id} until the prompt finishes.

    Returns the image references extracted from the finished history entry.

    Raises:
        HTTPException: 502 when the history entry reports a failed run; 504
            when the prompt has not completed within
            `settings.vision_request_timeout_seconds`.
        httpx.HTTPError: on transport failure or a non-2xx reply.
    """
    deadline = time.monotonic() + settings.vision_request_timeout_seconds
    while time.monotonic() < deadline:
        resp = await client.get(f"{endpoint}/history/{prompt_id}")
        resp.raise_for_status()
        doc = resp.json()
        entry = doc.get(prompt_id)
        if entry:
            # `status` may be absent or null; treat both as "no status yet".
            run_status = entry.get("status") or {}
            if run_status.get("completed"):
                return _extract_images(entry)
            # A failed run is recorded with status_str "error" and never flips
            # `completed`; fail fast instead of polling until the deadline.
            if run_status.get("status_str") == "error":
                raise HTTPException(
                    status_code=status.HTTP_502_BAD_GATEWAY,
                    detail=(
                        f"ComfyUI prompt {prompt_id} failed: "
                        f"{run_status.get('messages')}"
                    ),
                )
        await asyncio.sleep(POLL_INTERVAL_SECONDS)

    raise HTTPException(
        status_code=status.HTTP_504_GATEWAY_TIMEOUT,
        detail=(
            f"ComfyUI prompt {prompt_id} did not complete within "
            f"{settings.vision_request_timeout_seconds}s."
        ),
    )
|
||||
|
||||
|
||||
def _extract_images(entry: dict[str, Any]) -> list[dict[str, str]]:
    """Collect image references from a ComfyUI history entry.

    Walks every node under `entry["outputs"]` and returns one
    `{"filename", "subfolder", "type"}` dict per image, in encounter order.
    `subfolder` defaults to "" and `type` to "output" when missing.

    A missing or null `outputs`, nodes that are not mappings, nodes without
    images, and image records lacking a filename are skipped: a reference
    without a filename cannot be downloaded.
    """
    refs: list[dict[str, str]] = []
    outputs = entry.get("outputs") or {}
    for node in outputs.values():
        if not isinstance(node, dict):
            continue
        for img in node.get("images") or []:
            filename = img.get("filename")
            if not filename:
                continue
            refs.append(
                {
                    "filename": filename,
                    "subfolder": img.get("subfolder", ""),
                    "type": img.get("type", "output"),
                }
            )
    return refs
|
||||
|
||||
|
||||
async def _download_image(
    client: httpx.AsyncClient, endpoint: str, ref: dict[str, str]
) -> bytes:
    """Fetch one rendered image from ComfyUI's /view endpoint.

    `ref` must carry the `filename`, `subfolder` and `type` keys; they are sent
    as the query string. Returns the raw response body; a non-2xx reply raises
    through `raise_for_status`.
    """
    query = urllib.parse.urlencode(
        [(key, ref[key]) for key in ("filename", "subfolder", "type")]
    )
    response = await client.get(f"{endpoint}/view?{query}")
    response.raise_for_status()
    return response.content
|
||||
|
||||
|
||||
@router.post("/render", response_model=RenderResponse)
async def render(req: RenderRequest) -> RenderResponse:
    """Render an image via Spark 1 (FLUX.2-dev on ComfyUI).

    If `workflow_json` is supplied (e.g. from BTE's IRecipeAssembler), it is
    POSTed verbatim. Otherwise a stopgap FLUX.2 graph is built from `prompt`.

    Raises:
        HTTPException: 400 when neither `workflow_json` nor `prompt` is given,
            or when `workflow_json` is not valid JSON; 502 when ComfyUI is
            unreachable, answers with an error status, answers with a non-JSON
            body, or yields no output images. Timeouts raised by the polling
            helper pass through unchanged.
    """
    workflow_json = req.workflow_json
    if workflow_json:
        # Caller-supplied graph: malformed JSON is the caller's error, so
        # report it as 400 before any upstream traffic.
        try:
            json.loads(workflow_json)
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"workflow_json is not valid JSON: {e}",
            ) from e
    else:
        if not req.prompt:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Provide workflow_json or prompt.",
            )
        workflow_json = build_stopgap_workflow(
            req.prompt, req.width, req.height, req.guidance, req.steps, req.seed
        )

    endpoint = settings.spark1_flux_url.rstrip("/")
    started = time.monotonic()

    try:
        async with httpx.AsyncClient(
            timeout=settings.vision_request_timeout_seconds
        ) as client:
            prompt_id = await _queue_prompt(client, endpoint, workflow_json)
            images = await _poll_history(client, endpoint, prompt_id)
            if not images:
                raise HTTPException(
                    status_code=status.HTTP_502_BAD_GATEWAY,
                    detail=f"ComfyUI returned no output images for prompt {prompt_id}.",
                )
            # Only the first image reference is returned to the caller.
            png_bytes = await _download_image(client, endpoint, images[0])
    except httpx.HTTPStatusError as e:
        # ComfyUI answered with an error status (e.g. it rejected the graph).
        # Surface its status and body rather than claiming it is unreachable.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=(
                f"Spark 1 (ComfyUI) at {endpoint} returned HTTP "
                f"{e.response.status_code}: {e.response.text[:500]}"
            ),
        ) from e
    except httpx.HTTPError as e:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Spark 1 (ComfyUI) at {endpoint} unreachable: {type(e).__name__}: {e}",
        ) from e
    except ValueError as e:
        # Response.json() raises a ValueError subclass on a non-JSON body.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Spark 1 (ComfyUI) at {endpoint} returned a non-JSON response: {e}",
        ) from e

    duration_ms = int((time.monotonic() - started) * 1000)
    return RenderResponse(
        prompt_id=prompt_id,
        image_base64=base64.b64encode(png_bytes).decode("ascii"),
        content_type="image/png",
        duration_ms=duration_ms,
    )
|
||||
|
||||
@ -1,17 +1,193 @@
|
||||
"""VLM (vision-language model) analysis — stub until Phase 4b moves Qwen3-VL code."""
|
||||
"""VLM (vision-language model) analysis — proxies to Spark 2 (Qwen3-VL via vLLM).
|
||||
|
||||
Ported from BTE's OpenAiVlmClient.cs + VlmRubric.cs (Phase 4b). Cloud Anthropic
|
||||
dialect intentionally dropped — svrnty-vision is sovereign-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from svrnty_vision.settings import settings
|
||||
|
||||
router = APIRouter(prefix="/vlm", tags=["vlm"])
|
||||
|
||||
|
||||
class AnalyzeRequest(BaseModel):
    """At least one of `image_base64` or `image_url` must be supplied.

    `rubric_mode` is `polished` (premium aesthetic) or `ugc` (organic).
    `brand_context` is the prompt anchor (brand description / extraction directives).
    Set `rubric_mode = "raw"` to bypass the brand-scoring rubric and pass
    `brand_context` through verbatim (used by extraction prompts, e.g. screenshot DNA).
    """

    # Image payload: inline base64 takes precedence over image_url.
    image_base64: str | None = None
    image_url: str | None = None
    # MIME type used when the image is wrapped into a data: URI.
    content_type: str = "image/png"
    brand_context: str = ""
    rubric_mode: str = "polished"
    model: str | None = None  # override settings.spark2_vlm_model
    max_tokens: int = 1024
|
||||
|
||||
|
||||
class AnalyzeResponse(BaseModel):
    # Both scores stay null in raw mode, where no scores are parsed.
    brand_fit_score: Decimal | None = Field(
        default=None,
        description="0.00–5.00, or null when rubric_mode='raw' (no scores parsed)",
    )
    visual_polish_score: Decimal | None = Field(default=None)
    # Echo of the rubric mode the request was evaluated under.
    rubric_mode: str
    justification: str = Field(default="")
    # Identifier of the model that produced the response.
    model_id: str
    raw_scores_json: str = Field(
        description="The JSON object the VLM returned (or its full text if rubric_mode='raw')."
    )
|
||||
|
||||
|
||||
def build_rubric_prompt(brand_context: str, rubric_mode: str) -> str:
    """Build the text prompt sent to the VLM (port of VlmRubric.BuildRubricPrompt).

    In `raw` mode the caller's `brand_context` is returned unchanged and becomes
    the whole prompt (extraction-style calls). Any other mode wraps
    `brand_context` in the two-axis brand-scoring rubric, which asks for a
    single compact JSON object.
    """
    if rubric_mode == "raw":
        return brand_context

    # One entry per output line; empty entries produce the blank separator lines.
    lines = [
        "You are a brand-asset reviewer. Brand context:",
        "---",
        f"{brand_context}",
        "---",
        (
            f"Rubric mode: {rubric_mode} (polished = premium production aesthetic; "
            "ugc = organic, casual, hand-shot feel)."
        ),
        "",
        "Score the attached image on TWO axes, each 0.00 – 5.00:",
        "- brand_fit: how well it embodies the brand voice/tokens above.",
        "- visual_polish: technical execution under the chosen rubric mode.",
        "",
        "Respond with ONE compact JSON object only — no prose, no fences:",
        '{"brand_fit": <num>, "visual_polish": <num>, "justification": "<≤200 chars>"}',
    ]
    return "\n".join(lines)
|
||||
|
||||
|
||||
def parse_scores(
    response_text: str, rubric_mode: str, model_id: str
) -> AnalyzeResponse:
    """Turn the VLM's reply text into an AnalyzeResponse (port of VlmRubric.ParseScores).

    For `rubric_mode='raw'` the response text is returned untouched and both
    scores stay None; callers (e.g. BTE ImageSetSourceReader / ExtractBrandSaga)
    parse the JSON themselves.

    Otherwise the span from the first `{` to the last `}` is parsed as JSON, so
    leading or trailing prose around the object is tolerated.

    Raises:
        HTTPException: 502 when the reply holds no JSON object, the object does
            not parse, or `brand_fit` / `visual_polish` is missing or not a
            number.
    """
    if rubric_mode == "raw":
        return AnalyzeResponse(
            brand_fit_score=None,
            visual_polish_score=None,
            rubric_mode=rubric_mode,
            justification="",
            model_id=model_id,
            raw_scores_json=response_text,
        )

    trimmed = response_text.strip()
    start = trimmed.find("{")
    end = trimmed.rfind("}")
    if start < 0 or end <= start:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"VLM response did not contain a JSON object: {response_text}",
        )

    raw_json = trimmed[start : end + 1]
    try:
        obj = json.loads(raw_json)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"VLM JSON parse failed: {e}",
        ) from e

    # The model may omit a key or emit a non-numeric value (null, "N/A", ...).
    # That is an upstream fault, so report 502 instead of crashing with a 500.
    # decimal.InvalidOperation is an ArithmeticError subclass.
    try:
        brand_fit = Decimal(str(obj["brand_fit"]))
        visual_polish = Decimal(str(obj["visual_polish"]))
    except (KeyError, TypeError, ArithmeticError) as e:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=(
                "VLM JSON missing or non-numeric score "
                f"({type(e).__name__}: {e}): {raw_json}"
            ),
        ) from e

    return AnalyzeResponse(
        brand_fit_score=brand_fit,
        visual_polish_score=visual_polish,
        rubric_mode=rubric_mode,
        justification=str(obj.get("justification", "")),
        model_id=model_id,
        raw_scores_json=raw_json,
    )
|
||||
|
||||
|
||||
async def _resolve_data_uri(req: AnalyzeRequest) -> str:
    """Resolve the request's image into a URL string for the VLM payload.

    Resolution order:
      1. `image_base64` → wrapped into a `data:` URI using `content_type`.
      2. `image_url` starting with http://, https:// or data: → passed through.
      3. any other `image_url` → treated as a local file path, read and
         base64-encoded into a `data:` URI.

    Raises:
        HTTPException: 400 when neither field is set or the path cannot be read.
    """
    # Function-scope imports keep this block self-contained.
    import asyncio
    from pathlib import Path

    if req.image_base64:
        return f"data:{req.content_type};base64,{req.image_base64}"

    if not req.image_url:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Provide image_base64 or image_url.",
        )

    # vLLM Qwen3-VL accepts http(s) URLs directly; pass through.
    if req.image_url.startswith(("http://", "https://", "data:")):
        return req.image_url

    # NOTE(review): SECURITY — the path comes from the request body, so any
    # caller can make the gateway read any file this process can access and
    # forward it to the model. Confirm this endpoint is reachable by trusted
    # callers only, or restrict reads to an allow-listed directory.
    try:
        # Read off the event loop so a slow disk does not stall other requests.
        raw = await asyncio.to_thread(Path(req.image_url).read_bytes)
    except (OSError, ValueError) as e:
        # ValueError covers paths the OS layer rejects outright (embedded NUL).
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"image_url unreadable as file: {e}",
        ) from e
    return f"data:{req.content_type};base64,{base64.b64encode(raw).decode('ascii')}"
|
||||
|
||||
|
||||
@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze(req: AnalyzeRequest) -> AnalyzeResponse:
    """Analyze an image with Qwen3-VL on Spark 2 (vLLM, OpenAI-compatible).

    Builds a single-turn chat completion (rubric text + image), posts it to
    `/v1/chat/completions`, and parses the first choice's message content.

    Raises:
        HTTPException: 400 when no usable image is supplied; 502 when Spark 2
            is unreachable, answers with an error status, answers with a
            non-JSON body or an unexpected shape, or the content cannot be
            parsed into scores.
    """
    data_uri = await _resolve_data_uri(req)
    rubric = build_rubric_prompt(req.brand_context, req.rubric_mode)
    model = req.model or settings.spark2_vlm_model

    body: dict[str, Any] = {
        "model": model,
        "max_tokens": req.max_tokens,
        "temperature": 0,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": rubric},
                    {"type": "image_url", "image_url": {"url": data_uri}},
                ],
            }
        ],
    }

    url = settings.spark2_vlm_url.rstrip("/") + "/v1/chat/completions"
    try:
        async with httpx.AsyncClient(
            timeout=settings.vision_request_timeout_seconds
        ) as client:
            resp = await client.post(url, json=body)
            resp.raise_for_status()
            payload = resp.json()
    except httpx.HTTPStatusError as e:
        # vLLM answered with an error status; surface its status and body
        # rather than claiming it is unreachable.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=(
                f"Spark 2 (vLLM) at {url} returned HTTP "
                f"{e.response.status_code}: {e.response.text[:500]}"
            ),
        ) from e
    except httpx.HTTPError as e:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Spark 2 (vLLM) at {url} unreachable: {type(e).__name__}: {e}",
        ) from e
    except ValueError as e:
        # Response.json() raises a ValueError subclass on a non-JSON body.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Spark 2 (vLLM) at {url} returned a non-JSON response: {e}",
        ) from e

    try:
        # A null content is normalized to "" so the parser reports it as missing JSON.
        text = payload["choices"][0]["message"]["content"] or ""
    except (KeyError, IndexError, TypeError) as e:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Spark 2 response shape unexpected: {e}",
        ) from e

    return parse_scores(text, req.rubric_mode, model)
|
||||
|
||||
68
tests/test_flux_workflow.py
Normal file
68
tests/test_flux_workflow.py
Normal file
@ -0,0 +1,68 @@
|
||||
"""Pytest port of BTE's StopgapFluxWorkflowTests + smoke for /flux/render."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from svrnty_vision.routers.flux import build_stopgap_workflow
|
||||
from svrnty_vision.server import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_stopgap_workflow_includes_prompt_and_dimensions() -> None:
|
||||
raw = build_stopgap_workflow("a sunlit plate of food", 1024, 1024)
|
||||
assert "a sunlit plate of food" in raw
|
||||
graph = json.loads(raw)
|
||||
assert graph["5"]["inputs"]["width"] == 1024
|
||||
assert graph["5"]["inputs"]["height"] == 1024
|
||||
assert graph["10"]["inputs"]["vae_name"] == "flux2-vae.safetensors"
|
||||
assert graph["11"]["inputs"]["type"] == "flux2"
|
||||
assert graph["12"]["inputs"]["unet_name"] == "flux2_dev_fp8mixed.safetensors"
|
||||
|
||||
|
||||
def test_stopgap_workflow_seeds_vary_per_call() -> None:
|
||||
"""ComfyUI dedupes identical workflows (execution_cached → empty outputs)."""
|
||||
a = build_stopgap_workflow("x", 512, 512)
|
||||
b = build_stopgap_workflow("x", 512, 512)
|
||||
assert a != b
|
||||
|
||||
|
||||
def test_stopgap_workflow_uses_explicit_seed_when_supplied() -> None:
|
||||
a = build_stopgap_workflow("x", 512, 512, seed=42)
|
||||
b = build_stopgap_workflow("x", 512, 512, seed=42)
|
||||
assert a == b
|
||||
|
||||
|
||||
def test_render_requires_workflow_or_prompt() -> None:
|
||||
response = client.post("/flux/render", json={"width": 512, "height": 512})
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_render_returns_502_when_spark1_unreachable() -> None:
|
||||
class _StubClient:
|
||||
def __init__(self, *a, **kw):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *a):
|
||||
return False
|
||||
|
||||
async def post(self, *a, **kw):
|
||||
raise httpx.ConnectError("no comfy")
|
||||
|
||||
async def get(self, *a, **kw):
|
||||
raise httpx.ConnectError("no comfy")
|
||||
|
||||
with patch("svrnty_vision.routers.flux.httpx.AsyncClient", _StubClient):
|
||||
response = client.post(
|
||||
"/flux/render",
|
||||
json={"prompt": "test", "width": 512, "height": 512},
|
||||
)
|
||||
assert response.status_code == 502
|
||||
@ -15,21 +15,13 @@ def test_healthz_returns_200() -> None:
|
||||
assert "version" in body
|
||||
|
||||
|
||||
def test_vlm_analyze_returns_501() -> None:
|
||||
response = client.post("/vlm/analyze")
|
||||
assert response.status_code == 501
|
||||
|
||||
|
||||
def test_flux_render_returns_501() -> None:
|
||||
response = client.post("/flux/render")
|
||||
assert response.status_code == 501
|
||||
|
||||
|
||||
def test_palette_extract_returns_501() -> None:
|
||||
# Still a 4a stub — Phase 4b ported only VLM + FLUX; palette/rembg deferred.
|
||||
response = client.post("/palette/extract")
|
||||
assert response.status_code == 501
|
||||
|
||||
|
||||
def test_rembg_cutout_returns_501() -> None:
|
||||
# Still a 4a stub — Phase 4b ported only VLM + FLUX; palette/rembg deferred.
|
||||
response = client.post("/rembg/cutout")
|
||||
assert response.status_code == 501
|
||||
|
||||
147
tests/test_vlm_parse.py
Normal file
147
tests/test_vlm_parse.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""Pytest port of BTE's FakeVlmEvaluationParseTests + VlmRubric parse coverage.
|
||||
|
||||
These tests cover the pure-function side of the VLM router (rubric prompt + score
|
||||
parsing). The HTTP call to Spark 2 is exercised separately via TestClient with a
|
||||
mocked httpx transport.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from svrnty_vision.routers.vlm import build_rubric_prompt, parse_scores
|
||||
from svrnty_vision.server import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_build_rubric_prompt_includes_brand_and_mode() -> None:
|
||||
prompt = build_rubric_prompt("test-brand", "polished")
|
||||
assert "test-brand" in prompt
|
||||
assert "polished" in prompt
|
||||
assert "brand_fit" in prompt
|
||||
assert "visual_polish" in prompt
|
||||
|
||||
|
||||
def test_build_rubric_prompt_raw_passes_through_verbatim() -> None:
|
||||
assert build_rubric_prompt("anything goes", "raw") == "anything goes"
|
||||
|
||||
|
||||
def test_parse_scores_extracts_canned_payload() -> None:
|
||||
raw = '{"brand_fit":4.20,"visual_polish":3.80,"justification":"clean lighting"}'
|
||||
result = parse_scores(raw, "polished", "qwen-test")
|
||||
assert result.brand_fit_score == Decimal("4.20")
|
||||
assert result.visual_polish_score == Decimal("3.80")
|
||||
assert result.rubric_mode == "polished"
|
||||
assert result.model_id == "qwen-test"
|
||||
assert "brand_fit" in result.raw_scores_json
|
||||
|
||||
|
||||
def test_parse_scores_tolerates_leading_or_trailing_prose() -> None:
|
||||
raw = 'sure: {"brand_fit":2,"visual_polish":3}. done.'
|
||||
result = parse_scores(raw, "ugc", "m")
|
||||
assert result.brand_fit_score == Decimal("2")
|
||||
assert result.visual_polish_score == Decimal("3")
|
||||
|
||||
|
||||
def test_parse_scores_raw_mode_returns_text_untouched() -> None:
|
||||
raw = '{"subject":"plate","mood":"warm"}'
|
||||
result = parse_scores(raw, "raw", "m")
|
||||
assert result.brand_fit_score is None
|
||||
assert result.visual_polish_score is None
|
||||
assert result.raw_scores_json == raw
|
||||
|
||||
|
||||
def test_analyze_requires_image_input() -> None:
|
||||
response = client.post(
|
||||
"/vlm/analyze",
|
||||
json={"brand_context": "x", "rubric_mode": "polished"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_analyze_returns_502_when_spark2_unreachable() -> None:
|
||||
"""Smoke: with no Spark 2 (or a failing transport), gateway surfaces 502.
|
||||
|
||||
Uses a mock async client that raises ConnectError on POST.
|
||||
"""
|
||||
|
||||
class _StubClient:
|
||||
def __init__(self, *a, **kw):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *a):
|
||||
return False
|
||||
|
||||
async def post(self, *a, **kw):
|
||||
raise httpx.ConnectError("boom")
|
||||
|
||||
with patch("svrnty_vision.routers.vlm.httpx.AsyncClient", _StubClient):
|
||||
response = client.post(
|
||||
"/vlm/analyze",
|
||||
json={
|
||||
"image_base64": "ZmFrZQ==",
|
||||
"brand_context": "test",
|
||||
"rubric_mode": "polished",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 502
|
||||
|
||||
|
||||
def test_analyze_round_trip_with_mocked_spark2() -> None:
|
||||
"""Happy path: mock vLLM returns a well-formed score JSON; gateway parses it."""
|
||||
|
||||
canned_response = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": '{"brand_fit": 4.5, "visual_polish": 4.0, "justification": "ok"}'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
class _StubResponse:
|
||||
status_code = 200
|
||||
|
||||
def raise_for_status(self):
|
||||
return None
|
||||
|
||||
def json(self):
|
||||
return canned_response
|
||||
|
||||
class _StubClient:
|
||||
def __init__(self, *a, **kw):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *a):
|
||||
return False
|
||||
|
||||
async def post(self, *a, **kw):
|
||||
return _StubResponse()
|
||||
|
||||
with patch("svrnty_vision.routers.vlm.httpx.AsyncClient", _StubClient):
|
||||
response = client.post(
|
||||
"/vlm/analyze",
|
||||
json={
|
||||
"image_base64": "ZmFrZQ==",
|
||||
"brand_context": "test-brand",
|
||||
"rubric_mode": "polished",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
body = response.json()
|
||||
assert Decimal(body["brand_fit_score"]) == Decimal("4.5")
|
||||
assert Decimal(body["visual_polish_score"]) == Decimal("4.0")
|
||||
assert body["rubric_mode"] == "polished"
|
||||
Loading…
Reference in New Issue
Block a user