feat: port VLM + FLUX from BTE (Phase 4b)
VLM router (POST /vlm/analyze):
- Proxies to Spark 2 (Qwen3-VL via vLLM, OpenAI-compatible /v1/chat/completions)
- Port of BTE Svrnty.Bte.Domain/Features/AssetContext/OpenAiVlmClient.cs
+ VlmRubric.cs (rubric prompt builder + score parser)
- Anthropic dialect intentionally dropped — sovereign-only
- New rubric_mode="raw" passes brand_context through verbatim so BTE
ExtractBrandSaga / ImageSetSourceReader (extraction-style prompts that
expect their own JSON schema) get unwrapped JSON back without losing
the score-axis path
FLUX router (POST /flux/render):
- Proxies to Spark 1 (FLUX.2-dev on ComfyUI; /prompt + /history poll + /view)
- Port of SparkBComfyClient.cs + LocalFluxImageProvider.cs + StopgapFluxWorkflow.cs
- Accepts a pre-assembled workflow_json (BTE IRecipeAssembler emits one)
or builds the stopgap FLUX.2 graph from prompt + dims
Tests (pytest):
- test_vlm_parse.py — rubric prompt + score parse, 502 on Spark-down, mocked round-trip
- test_flux_workflow.py — stopgap graph shape, seed variance/determinism, 502 on Spark-down
- test_healthz.py updated (palette/rembg still 4a stubs)
16 pytest tests green.
Smoke (no Spark reachable):
- GET /healthz → 200 {"status":"ok"}
- POST /vlm/analyze → 502 "Spark 2 unreachable" (clear error)
- POST /flux/render → 502 "Spark 1 unreachable" (clear error)
Per BTE refactor audit §3 V — vision capabilities extracted from BTE to the
sovereign vision gateway. Phase 4c (delete-from-BTE) + Phase 4d (HTTP adapter)
follow in BTE.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
2a90c3f884
commit
e810c72ffa
@ -1,17 +1,247 @@
|
|||||||
"""FLUX image generation — stub until Phase 4b wires the ComfyUI HTTP client."""
|
"""FLUX image generation — proxies to Spark 1 (ComfyUI HTTP API).
|
||||||
|
|
||||||
|
Ported from BTE's SparkBComfyClient.cs + LocalFluxImageProvider.cs +
|
||||||
|
StopgapFluxWorkflow.cs (Phase 4b). Workflow is either supplied by the caller
|
||||||
|
(BTE's IRecipeAssembler emits a full graph) or assembled inline from the
|
||||||
|
stopgap FLUX.2 template.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import urllib.parse
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import APIRouter, HTTPException, status
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from svrnty_vision.settings import settings
|
||||||
|
|
||||||
router = APIRouter(prefix="/flux", tags=["flux"])
|
router = APIRouter(prefix="/flux", tags=["flux"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("/render")
|
# Default values match BTE's StopgapFluxWorkflow.cs.
|
||||||
async def render() -> None:
|
DEFAULT_GUIDANCE = 2.5
|
||||||
"""Render an image via Spark 1 (FLUX on ComfyUI).
|
DEFAULT_STEPS = 20
|
||||||
|
POLL_INTERVAL_SECONDS = 2.0
|
||||||
|
|
||||||
Phase 4a: stub. Phase 4b: proxies to Spark 1 ComfyUI workflow.
|
|
||||||
|
class RenderRequest(BaseModel):
|
||||||
|
"""Either `workflow_json` (a pre-assembled ComfyUI graph) or `prompt` + dims
|
||||||
|
(we build the stopgap FLUX.2-dev graph). `workflow_json` wins when both set.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
prompt: str | None = None
|
||||||
|
width: int = 1024
|
||||||
|
height: int = 1024
|
||||||
|
workflow_json: str | None = None
|
||||||
|
guidance: float = DEFAULT_GUIDANCE
|
||||||
|
steps: int = DEFAULT_STEPS
|
||||||
|
seed: int | None = None # null → random per call (ComfyUI dedup avoidance)
|
||||||
|
|
||||||
|
|
||||||
|
class RenderResponse(BaseModel):
|
||||||
|
prompt_id: str
|
||||||
|
image_base64: str
|
||||||
|
content_type: str = "image/png"
|
||||||
|
duration_ms: int
|
||||||
|
provider: str = "local"
|
||||||
|
model: str = "flux2"
|
||||||
|
|
||||||
|
|
||||||
|
def build_stopgap_workflow(
|
||||||
|
prompt: str,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
guidance: float = DEFAULT_GUIDANCE,
|
||||||
|
steps: int = DEFAULT_STEPS,
|
||||||
|
seed: int | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Ported from StopgapFluxWorkflow.Build (Svrnty.Bte.Shared.Recipes).
|
||||||
|
|
||||||
|
FLUX.2 stack on Spark 1: flux2_dev_fp8mixed UNET, Mistral-3 CLIP (type "flux2"),
|
||||||
|
flux2-vae. KSampler cfg=1.0; distilled guidance via FluxGuidance node.
|
||||||
|
Random seed per call: ComfyUI dedupes identical workflows.
|
||||||
|
"""
|
||||||
|
if seed is None:
|
||||||
|
seed = random.randint(1, 2**31 - 1)
|
||||||
|
|
||||||
|
graph: dict[str, Any] = {
|
||||||
|
"5": {
|
||||||
|
"class_type": "EmptySD3LatentImage",
|
||||||
|
"inputs": {"width": width, "height": height, "batch_size": 1},
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"inputs": {"clip": ["11", 0], "text": prompt},
|
||||||
|
},
|
||||||
|
"7": {
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"inputs": {"clip": ["11", 0], "text": ""},
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"inputs": {"samples": ["13", 0], "vae": ["10", 0]},
|
||||||
|
},
|
||||||
|
"9": {
|
||||||
|
"class_type": "SaveImage",
|
||||||
|
"inputs": {"filename_prefix": "bte", "images": ["8", 0]},
|
||||||
|
},
|
||||||
|
"10": {
|
||||||
|
"class_type": "VAELoader",
|
||||||
|
"inputs": {"vae_name": "flux2-vae.safetensors"},
|
||||||
|
},
|
||||||
|
"11": {
|
||||||
|
"class_type": "CLIPLoader",
|
||||||
|
"inputs": {
|
||||||
|
"clip_name": "mistral_3_small_flux2_fp8.safetensors",
|
||||||
|
"type": "flux2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"12": {
|
||||||
|
"class_type": "UNETLoader",
|
||||||
|
"inputs": {
|
||||||
|
"unet_name": "flux2_dev_fp8mixed.safetensors",
|
||||||
|
"weight_dtype": "default",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"13": {
|
||||||
|
"class_type": "KSampler",
|
||||||
|
"inputs": {
|
||||||
|
"model": ["12", 0],
|
||||||
|
"positive": ["26", 0],
|
||||||
|
"negative": ["7", 0],
|
||||||
|
"latent_image": ["5", 0],
|
||||||
|
"seed": seed,
|
||||||
|
"steps": steps,
|
||||||
|
"cfg": 1.0,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "simple",
|
||||||
|
"denoise": 1.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"26": {
|
||||||
|
"class_type": "FluxGuidance",
|
||||||
|
"inputs": {"conditioning": ["6", 0], "guidance": guidance},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return json.dumps(graph)
|
||||||
|
|
||||||
|
|
||||||
|
async def _queue_prompt(
|
||||||
|
client: httpx.AsyncClient, endpoint: str, workflow_json: str
|
||||||
|
) -> str:
|
||||||
|
workflow = json.loads(workflow_json)
|
||||||
|
resp = await client.post(f"{endpoint}/prompt", json={"prompt": workflow})
|
||||||
|
resp.raise_for_status()
|
||||||
|
body = resp.json()
|
||||||
|
if "prompt_id" not in body:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"ComfyUI /prompt returned no prompt_id: {body}",
|
||||||
|
)
|
||||||
|
return body["prompt_id"]
|
||||||
|
|
||||||
|
|
||||||
|
async def _poll_history(
|
||||||
|
client: httpx.AsyncClient, endpoint: str, prompt_id: str
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
deadline = time.monotonic() + settings.vision_request_timeout_seconds
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
resp = await client.get(f"{endpoint}/history/{prompt_id}")
|
||||||
|
resp.raise_for_status()
|
||||||
|
doc = resp.json()
|
||||||
|
entry = doc.get(prompt_id)
|
||||||
|
if entry and entry.get("status", {}).get("completed"):
|
||||||
|
return _extract_images(entry)
|
||||||
|
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||||
detail="flux.render not implemented in Phase 4a — see BTE-REFACTOR-EXECUTION-PLAN Phase 4b",
|
detail=(
|
||||||
|
f"ComfyUI prompt {prompt_id} did not complete within "
|
||||||
|
f"{settings.vision_request_timeout_seconds}s."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_images(entry: dict[str, Any]) -> list[dict[str, str]]:
|
||||||
|
refs: list[dict[str, str]] = []
|
||||||
|
outputs = entry.get("outputs", {})
|
||||||
|
for node in outputs.values():
|
||||||
|
for img in node.get("images", []) or []:
|
||||||
|
refs.append(
|
||||||
|
{
|
||||||
|
"filename": img.get("filename", ""),
|
||||||
|
"subfolder": img.get("subfolder", ""),
|
||||||
|
"type": img.get("type", "output"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return refs
|
||||||
|
|
||||||
|
|
||||||
|
async def _download_image(
|
||||||
|
client: httpx.AsyncClient, endpoint: str, ref: dict[str, str]
|
||||||
|
) -> bytes:
|
||||||
|
qs = urllib.parse.urlencode(
|
||||||
|
{
|
||||||
|
"filename": ref["filename"],
|
||||||
|
"subfolder": ref["subfolder"],
|
||||||
|
"type": ref["type"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
resp = await client.get(f"{endpoint}/view?{qs}")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.content
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/render", response_model=RenderResponse)
|
||||||
|
async def render(req: RenderRequest) -> RenderResponse:
|
||||||
|
"""Render an image via Spark 1 (FLUX.2-dev on ComfyUI).
|
||||||
|
|
||||||
|
If `workflow_json` is supplied (e.g. from BTE's IRecipeAssembler), it is
|
||||||
|
POSTed verbatim. Otherwise a stopgap FLUX.2 graph is built from `prompt`.
|
||||||
|
"""
|
||||||
|
workflow_json = req.workflow_json
|
||||||
|
if not workflow_json:
|
||||||
|
if not req.prompt:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Provide workflow_json or prompt.",
|
||||||
|
)
|
||||||
|
workflow_json = build_stopgap_workflow(
|
||||||
|
req.prompt, req.width, req.height, req.guidance, req.steps, req.seed
|
||||||
|
)
|
||||||
|
|
||||||
|
endpoint = settings.spark1_flux_url.rstrip("/")
|
||||||
|
started = time.monotonic()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
timeout=settings.vision_request_timeout_seconds
|
||||||
|
) as client:
|
||||||
|
prompt_id = await _queue_prompt(client, endpoint, workflow_json)
|
||||||
|
images = await _poll_history(client, endpoint, prompt_id)
|
||||||
|
if not images:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"ComfyUI returned no output images for prompt {prompt_id}.",
|
||||||
|
)
|
||||||
|
png_bytes = await _download_image(client, endpoint, images[0])
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"Spark 1 (ComfyUI) at {endpoint} unreachable: {type(e).__name__}: {e}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
duration_ms = int((time.monotonic() - started) * 1000)
|
||||||
|
return RenderResponse(
|
||||||
|
prompt_id=prompt_id,
|
||||||
|
image_base64=base64.b64encode(png_bytes).decode("ascii"),
|
||||||
|
content_type="image/png",
|
||||||
|
duration_ms=duration_ms,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,17 +1,193 @@
|
|||||||
"""VLM (vision-language model) analysis — stub until Phase 4b moves Qwen3-VL code."""
|
"""VLM (vision-language model) analysis — proxies to Spark 2 (Qwen3-VL via vLLM).
|
||||||
|
|
||||||
|
Ported from BTE's OpenAiVlmClient.cs + VlmRubric.cs (Phase 4b). Cloud Anthropic
|
||||||
|
dialect intentionally dropped — svrnty-vision is sovereign-only.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import APIRouter, HTTPException, status
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from svrnty_vision.settings import settings
|
||||||
|
|
||||||
router = APIRouter(prefix="/vlm", tags=["vlm"])
|
router = APIRouter(prefix="/vlm", tags=["vlm"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("/analyze")
|
class AnalyzeRequest(BaseModel):
|
||||||
async def analyze() -> None:
|
"""At least one of `image_base64` or `image_url` must be supplied.
|
||||||
"""Analyze an image with a vision-language model.
|
|
||||||
|
|
||||||
Phase 4a: stub. Phase 4b: proxies to Spark 2 (Qwen3-VL via vLLM).
|
`rubric_mode` is `polished` (premium aesthetic) or `ugc` (organic).
|
||||||
|
`brand_context` is the prompt anchor (brand description / extraction directives).
|
||||||
|
Set `rubric_mode = "raw"` to bypass the brand-scoring rubric and pass
|
||||||
|
`brand_context` through verbatim (used by extraction prompts, e.g. screenshot DNA).
|
||||||
"""
|
"""
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
image_base64: str | None = None
|
||||||
detail="vlm.analyze not implemented in Phase 4a — see BTE-REFACTOR-EXECUTION-PLAN Phase 4b",
|
image_url: str | None = None
|
||||||
|
content_type: str = "image/png"
|
||||||
|
brand_context: str = ""
|
||||||
|
rubric_mode: str = "polished"
|
||||||
|
model: str | None = None # override settings.spark2_vlm_model
|
||||||
|
max_tokens: int = 1024
|
||||||
|
|
||||||
|
|
||||||
|
class AnalyzeResponse(BaseModel):
|
||||||
|
brand_fit_score: Decimal | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="0.00–5.00, or null when rubric_mode='raw' (no scores parsed)",
|
||||||
)
|
)
|
||||||
|
visual_polish_score: Decimal | None = None
|
||||||
|
rubric_mode: str
|
||||||
|
justification: str = ""
|
||||||
|
model_id: str
|
||||||
|
raw_scores_json: str = Field(
|
||||||
|
description="The JSON object the VLM returned (or its full text if rubric_mode='raw')."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_rubric_prompt(brand_context: str, rubric_mode: str) -> str:
|
||||||
|
"""Ported from VlmRubric.BuildRubricPrompt.
|
||||||
|
|
||||||
|
`rubric_mode='raw'` bypasses the brand-scoring rubric — caller's brand_context
|
||||||
|
becomes the full prompt (used for extraction-style calls).
|
||||||
|
"""
|
||||||
|
if rubric_mode == "raw":
|
||||||
|
return brand_context
|
||||||
|
|
||||||
|
return (
|
||||||
|
"You are a brand-asset reviewer. Brand context:\n"
|
||||||
|
"---\n"
|
||||||
|
f"{brand_context}\n"
|
||||||
|
"---\n"
|
||||||
|
f"Rubric mode: {rubric_mode} (polished = premium production aesthetic; "
|
||||||
|
"ugc = organic, casual, hand-shot feel).\n\n"
|
||||||
|
"Score the attached image on TWO axes, each 0.00 – 5.00:\n"
|
||||||
|
"- brand_fit: how well it embodies the brand voice/tokens above.\n"
|
||||||
|
"- visual_polish: technical execution under the chosen rubric mode.\n\n"
|
||||||
|
"Respond with ONE compact JSON object only — no prose, no fences:\n"
|
||||||
|
'{"brand_fit": <num>, "visual_polish": <num>, "justification": "<≤200 chars>"}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_scores(
|
||||||
|
response_text: str, rubric_mode: str, model_id: str
|
||||||
|
) -> AnalyzeResponse:
|
||||||
|
"""Ported from VlmRubric.ParseScores.
|
||||||
|
|
||||||
|
For `rubric_mode='raw'` we return the response text untouched; scores stay None.
|
||||||
|
Callers (e.g. BTE ImageSetSourceReader / ExtractBrandSaga) parse the JSON
|
||||||
|
themselves.
|
||||||
|
"""
|
||||||
|
if rubric_mode == "raw":
|
||||||
|
return AnalyzeResponse(
|
||||||
|
brand_fit_score=None,
|
||||||
|
visual_polish_score=None,
|
||||||
|
rubric_mode=rubric_mode,
|
||||||
|
justification="",
|
||||||
|
model_id=model_id,
|
||||||
|
raw_scores_json=response_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
trimmed = response_text.strip()
|
||||||
|
start = trimmed.find("{")
|
||||||
|
end = trimmed.rfind("}")
|
||||||
|
if start < 0 or end <= start:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"VLM response did not contain a JSON object: {response_text}",
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_json = trimmed[start : end + 1]
|
||||||
|
try:
|
||||||
|
obj = json.loads(raw_json)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"VLM JSON parse failed: {e}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return AnalyzeResponse(
|
||||||
|
brand_fit_score=Decimal(str(obj["brand_fit"])),
|
||||||
|
visual_polish_score=Decimal(str(obj["visual_polish"])),
|
||||||
|
rubric_mode=rubric_mode,
|
||||||
|
justification=str(obj.get("justification", "")),
|
||||||
|
model_id=model_id,
|
||||||
|
raw_scores_json=raw_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_data_uri(req: AnalyzeRequest) -> str:
|
||||||
|
if req.image_base64:
|
||||||
|
return f"data:{req.content_type};base64,{req.image_base64}"
|
||||||
|
if req.image_url:
|
||||||
|
# vLLM Qwen3-VL accepts http(s) URLs directly; pass through.
|
||||||
|
if req.image_url.startswith(("http://", "https://", "data:")):
|
||||||
|
return req.image_url
|
||||||
|
# Otherwise treat as local path → read + b64.
|
||||||
|
try:
|
||||||
|
with open(req.image_url, "rb") as f:
|
||||||
|
raw = f.read()
|
||||||
|
except OSError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"image_url unreadable as file: {e}",
|
||||||
|
) from e
|
||||||
|
return f"data:{req.content_type};base64,{base64.b64encode(raw).decode('ascii')}"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Provide image_base64 or image_url.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/analyze", response_model=AnalyzeResponse)
|
||||||
|
async def analyze(req: AnalyzeRequest) -> AnalyzeResponse:
|
||||||
|
"""Analyze an image with Qwen3-VL on Spark 2 (vLLM, OpenAI-compatible)."""
|
||||||
|
data_uri = await _resolve_data_uri(req)
|
||||||
|
rubric = build_rubric_prompt(req.brand_context, req.rubric_mode)
|
||||||
|
model = req.model or settings.spark2_vlm_model
|
||||||
|
|
||||||
|
body: dict[str, Any] = {
|
||||||
|
"model": model,
|
||||||
|
"max_tokens": req.max_tokens,
|
||||||
|
"temperature": 0,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": rubric},
|
||||||
|
{"type": "image_url", "image_url": {"url": data_uri}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
url = settings.spark2_vlm_url.rstrip("/") + "/v1/chat/completions"
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
timeout=settings.vision_request_timeout_seconds
|
||||||
|
) as client:
|
||||||
|
resp = await client.post(url, json=body)
|
||||||
|
resp.raise_for_status()
|
||||||
|
payload = resp.json()
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"Spark 2 (vLLM) at {url} unreachable: {type(e).__name__}: {e}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
try:
|
||||||
|
text = payload["choices"][0]["message"]["content"] or ""
|
||||||
|
except (KeyError, IndexError, TypeError) as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"Spark 2 response shape unexpected: {e}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return parse_scores(text, req.rubric_mode, model)
|
||||||
|
|||||||
68
tests/test_flux_workflow.py
Normal file
68
tests/test_flux_workflow.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
"""Pytest port of BTE's StopgapFluxWorkflowTests + smoke for /flux/render."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from svrnty_vision.routers.flux import build_stopgap_workflow
|
||||||
|
from svrnty_vision.server import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stopgap_workflow_includes_prompt_and_dimensions() -> None:
|
||||||
|
raw = build_stopgap_workflow("a sunlit plate of food", 1024, 1024)
|
||||||
|
assert "a sunlit plate of food" in raw
|
||||||
|
graph = json.loads(raw)
|
||||||
|
assert graph["5"]["inputs"]["width"] == 1024
|
||||||
|
assert graph["5"]["inputs"]["height"] == 1024
|
||||||
|
assert graph["10"]["inputs"]["vae_name"] == "flux2-vae.safetensors"
|
||||||
|
assert graph["11"]["inputs"]["type"] == "flux2"
|
||||||
|
assert graph["12"]["inputs"]["unet_name"] == "flux2_dev_fp8mixed.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
def test_stopgap_workflow_seeds_vary_per_call() -> None:
|
||||||
|
"""ComfyUI dedupes identical workflows (execution_cached → empty outputs)."""
|
||||||
|
a = build_stopgap_workflow("x", 512, 512)
|
||||||
|
b = build_stopgap_workflow("x", 512, 512)
|
||||||
|
assert a != b
|
||||||
|
|
||||||
|
|
||||||
|
def test_stopgap_workflow_uses_explicit_seed_when_supplied() -> None:
|
||||||
|
a = build_stopgap_workflow("x", 512, 512, seed=42)
|
||||||
|
b = build_stopgap_workflow("x", 512, 512, seed=42)
|
||||||
|
assert a == b
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_requires_workflow_or_prompt() -> None:
|
||||||
|
response = client.post("/flux/render", json={"width": 512, "height": 512})
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_returns_502_when_spark1_unreachable() -> None:
|
||||||
|
class _StubClient:
|
||||||
|
def __init__(self, *a, **kw):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *a):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def post(self, *a, **kw):
|
||||||
|
raise httpx.ConnectError("no comfy")
|
||||||
|
|
||||||
|
async def get(self, *a, **kw):
|
||||||
|
raise httpx.ConnectError("no comfy")
|
||||||
|
|
||||||
|
with patch("svrnty_vision.routers.flux.httpx.AsyncClient", _StubClient):
|
||||||
|
response = client.post(
|
||||||
|
"/flux/render",
|
||||||
|
json={"prompt": "test", "width": 512, "height": 512},
|
||||||
|
)
|
||||||
|
assert response.status_code == 502
|
||||||
@ -15,21 +15,13 @@ def test_healthz_returns_200() -> None:
|
|||||||
assert "version" in body
|
assert "version" in body
|
||||||
|
|
||||||
|
|
||||||
def test_vlm_analyze_returns_501() -> None:
|
|
||||||
response = client.post("/vlm/analyze")
|
|
||||||
assert response.status_code == 501
|
|
||||||
|
|
||||||
|
|
||||||
def test_flux_render_returns_501() -> None:
|
|
||||||
response = client.post("/flux/render")
|
|
||||||
assert response.status_code == 501
|
|
||||||
|
|
||||||
|
|
||||||
def test_palette_extract_returns_501() -> None:
|
def test_palette_extract_returns_501() -> None:
|
||||||
|
# Still a 4a stub — Phase 4c moved only VLM + FLUX, palette/rembg deferred.
|
||||||
response = client.post("/palette/extract")
|
response = client.post("/palette/extract")
|
||||||
assert response.status_code == 501
|
assert response.status_code == 501
|
||||||
|
|
||||||
|
|
||||||
def test_rembg_cutout_returns_501() -> None:
|
def test_rembg_cutout_returns_501() -> None:
|
||||||
|
# Still a 4a stub — Phase 4c moved only VLM + FLUX, palette/rembg deferred.
|
||||||
response = client.post("/rembg/cutout")
|
response = client.post("/rembg/cutout")
|
||||||
assert response.status_code == 501
|
assert response.status_code == 501
|
||||||
|
|||||||
147
tests/test_vlm_parse.py
Normal file
147
tests/test_vlm_parse.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
"""Pytest port of BTE's FakeVlmEvaluationParseTests + VlmRubric parse coverage.
|
||||||
|
|
||||||
|
These tests cover the pure-function side of the VLM router (rubric prompt + score
|
||||||
|
parsing). The HTTP call to Spark 2 is exercised separately via TestClient with a
|
||||||
|
mocked httpx transport.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from decimal import Decimal
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from svrnty_vision.routers.vlm import build_rubric_prompt, parse_scores
|
||||||
|
from svrnty_vision.server import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_rubric_prompt_includes_brand_and_mode() -> None:
|
||||||
|
prompt = build_rubric_prompt("test-brand", "polished")
|
||||||
|
assert "test-brand" in prompt
|
||||||
|
assert "polished" in prompt
|
||||||
|
assert "brand_fit" in prompt
|
||||||
|
assert "visual_polish" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_rubric_prompt_raw_passes_through_verbatim() -> None:
|
||||||
|
assert build_rubric_prompt("anything goes", "raw") == "anything goes"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_scores_extracts_canned_payload() -> None:
|
||||||
|
raw = '{"brand_fit":4.20,"visual_polish":3.80,"justification":"clean lighting"}'
|
||||||
|
result = parse_scores(raw, "polished", "qwen-test")
|
||||||
|
assert result.brand_fit_score == Decimal("4.20")
|
||||||
|
assert result.visual_polish_score == Decimal("3.80")
|
||||||
|
assert result.rubric_mode == "polished"
|
||||||
|
assert result.model_id == "qwen-test"
|
||||||
|
assert "brand_fit" in result.raw_scores_json
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_scores_tolerates_leading_or_trailing_prose() -> None:
|
||||||
|
raw = 'sure: {"brand_fit":2,"visual_polish":3}. done.'
|
||||||
|
result = parse_scores(raw, "ugc", "m")
|
||||||
|
assert result.brand_fit_score == Decimal("2")
|
||||||
|
assert result.visual_polish_score == Decimal("3")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_scores_raw_mode_returns_text_untouched() -> None:
|
||||||
|
raw = '{"subject":"plate","mood":"warm"}'
|
||||||
|
result = parse_scores(raw, "raw", "m")
|
||||||
|
assert result.brand_fit_score is None
|
||||||
|
assert result.visual_polish_score is None
|
||||||
|
assert result.raw_scores_json == raw
|
||||||
|
|
||||||
|
|
||||||
|
def test_analyze_requires_image_input() -> None:
|
||||||
|
response = client.post(
|
||||||
|
"/vlm/analyze",
|
||||||
|
json={"brand_context": "x", "rubric_mode": "polished"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_analyze_returns_502_when_spark2_unreachable() -> None:
|
||||||
|
"""Smoke: with no Spark 2 (or a failing transport), gateway surfaces 502.
|
||||||
|
|
||||||
|
Uses a mock async client that raises ConnectError on POST.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class _StubClient:
|
||||||
|
def __init__(self, *a, **kw):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *a):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def post(self, *a, **kw):
|
||||||
|
raise httpx.ConnectError("boom")
|
||||||
|
|
||||||
|
with patch("svrnty_vision.routers.vlm.httpx.AsyncClient", _StubClient):
|
||||||
|
response = client.post(
|
||||||
|
"/vlm/analyze",
|
||||||
|
json={
|
||||||
|
"image_base64": "ZmFrZQ==",
|
||||||
|
"brand_context": "test",
|
||||||
|
"rubric_mode": "polished",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 502
|
||||||
|
|
||||||
|
|
||||||
|
def test_analyze_round_trip_with_mocked_spark2() -> None:
|
||||||
|
"""Happy path: mock vLLM returns a well-formed score JSON; gateway parses it."""
|
||||||
|
|
||||||
|
canned_response = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": '{"brand_fit": 4.5, "visual_polish": 4.0, "justification": "ok"}'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
class _StubResponse:
|
||||||
|
status_code = 200
|
||||||
|
|
||||||
|
def raise_for_status(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return canned_response
|
||||||
|
|
||||||
|
class _StubClient:
|
||||||
|
def __init__(self, *a, **kw):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *a):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def post(self, *a, **kw):
|
||||||
|
return _StubResponse()
|
||||||
|
|
||||||
|
with patch("svrnty_vision.routers.vlm.httpx.AsyncClient", _StubClient):
|
||||||
|
response = client.post(
|
||||||
|
"/vlm/analyze",
|
||||||
|
json={
|
||||||
|
"image_base64": "ZmFrZQ==",
|
||||||
|
"brand_context": "test-brand",
|
||||||
|
"rubric_mode": "polished",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
body = response.json()
|
||||||
|
assert Decimal(body["brand_fit_score"]) == Decimal("4.5")
|
||||||
|
assert Decimal(body["visual_polish_score"]) == Decimal("4.0")
|
||||||
|
assert body["rubric_mode"] == "polished"
|
||||||
Loading…
Reference in New Issue
Block a user