diff --git a/src/svrnty_vision/routers/flux.py b/src/svrnty_vision/routers/flux.py index 770731f..9afae86 100644 --- a/src/svrnty_vision/routers/flux.py +++ b/src/svrnty_vision/routers/flux.py @@ -1,17 +1,247 @@ -"""FLUX image generation — stub until Phase 4b wires the ComfyUI HTTP client.""" +"""FLUX image generation — proxies to Spark 1 (ComfyUI HTTP API). +Ported from BTE's SparkBComfyClient.cs + LocalFluxImageProvider.cs + +StopgapFluxWorkflow.cs (Phase 4b). Workflow is either supplied by the caller +(BTE's IRecipeAssembler emits a full graph) or assembled inline from the +stopgap FLUX.2 template. +""" + +from __future__ import annotations + +import asyncio +import base64 +import json +import random +import time +import urllib.parse +from typing import Any + +import httpx from fastapi import APIRouter, HTTPException, status +from pydantic import BaseModel, Field + +from svrnty_vision.settings import settings router = APIRouter(prefix="/flux", tags=["flux"]) -@router.post("/render") -async def render() -> None: - """Render an image via Spark 1 (FLUX on ComfyUI). +# Default values match BTE's StopgapFluxWorkflow.cs. +DEFAULT_GUIDANCE = 2.5 +DEFAULT_STEPS = 20 +POLL_INTERVAL_SECONDS = 2.0 - Phase 4a: stub. Phase 4b: proxies to Spark 1 ComfyUI workflow. + +class RenderRequest(BaseModel): + """Either `workflow_json` (a pre-assembled ComfyUI graph) or `prompt` + dims + (we build the stopgap FLUX.2-dev graph). `workflow_json` wins when both set. """ + + prompt: str | None = None + width: int = 1024 + height: int = 1024 + workflow_json: str | None = None + guidance: float = DEFAULT_GUIDANCE + steps: int = DEFAULT_STEPS + seed: int | None = None # null → random per call (ComfyUI dedup avoidance) + + +class RenderResponse(BaseModel): + prompt_id: str + image_base64: str + content_type: str = "image/png" + duration_ms: int + provider: str = "local" + model: str = "flux2" + + +def build_stopgap_workflow( + prompt: str, + width: int, + height: int, + guidance: float = DEFAULT_GUIDANCE, + steps: int = DEFAULT_STEPS, + seed: int | None = None, +) -> str: + """Ported from StopgapFluxWorkflow.Build (Svrnty.Bte.Shared.Recipes). + + FLUX.2 stack on Spark 1: flux2_dev_fp8mixed UNET, Mistral-3 CLIP (type "flux2"), + flux2-vae. KSampler cfg=1.0; distilled guidance via FluxGuidance node. + Random seed per call: ComfyUI dedupes identical workflows. + """ + if seed is None: + seed = random.randint(1, 2**31 - 1) + + graph: dict[str, Any] = { + "5": { + "class_type": "EmptySD3LatentImage", + "inputs": {"width": width, "height": height, "batch_size": 1}, + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": {"clip": ["11", 0], "text": prompt}, + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": {"clip": ["11", 0], "text": ""}, + }, + "8": { + "class_type": "VAEDecode", + "inputs": {"samples": ["13", 0], "vae": ["10", 0]}, + }, + "9": { + "class_type": "SaveImage", + "inputs": {"filename_prefix": "bte", "images": ["8", 0]}, + }, + "10": { + "class_type": "VAELoader", + "inputs": {"vae_name": "flux2-vae.safetensors"}, + }, + "11": { + "class_type": "CLIPLoader", + "inputs": { + "clip_name": "mistral_3_small_flux2_fp8.safetensors", + "type": "flux2", + }, + }, + "12": { + "class_type": "UNETLoader", + "inputs": { + "unet_name": "flux2_dev_fp8mixed.safetensors", + "weight_dtype": "default", + }, + }, + "13": { + "class_type": "KSampler", + "inputs": { + "model": ["12", 0], + "positive": ["26", 0], + "negative": ["7", 0], + "latent_image": ["5", 0], + "seed": seed, + "steps": steps, + "cfg": 1.0, + "sampler_name": "euler", + "scheduler": "simple", + "denoise": 1.0, + }, + }, + "26": { + "class_type": "FluxGuidance", + "inputs": {"conditioning": ["6", 0], "guidance": guidance}, + }, + } + return json.dumps(graph) + + +async def _queue_prompt( + client: httpx.AsyncClient, endpoint: str, workflow_json: str +) -> str: + workflow = json.loads(workflow_json) + resp = await client.post(f"{endpoint}/prompt", json={"prompt": workflow}) + resp.raise_for_status() + body = resp.json() + if "prompt_id" not in body: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"ComfyUI /prompt returned no prompt_id: {body}", + ) + return body["prompt_id"] + + +async def _poll_history( + client: httpx.AsyncClient, endpoint: str, prompt_id: str +) -> list[dict[str, str]]: + deadline = time.monotonic() + settings.vision_request_timeout_seconds + while time.monotonic() < deadline: + resp = await client.get(f"{endpoint}/history/{prompt_id}") + resp.raise_for_status() + doc = resp.json() + entry = doc.get(prompt_id) + if entry and entry.get("status", {}).get("completed"): + return _extract_images(entry) + await asyncio.sleep(POLL_INTERVAL_SECONDS) raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="flux.render not implemented in Phase 4a — see BTE-REFACTOR-EXECUTION-PLAN Phase 4b", + status_code=status.HTTP_504_GATEWAY_TIMEOUT, + detail=( + f"ComfyUI prompt {prompt_id} did not complete within " + f"{settings.vision_request_timeout_seconds}s." + ), + ) + + +def _extract_images(entry: dict[str, Any]) -> list[dict[str, str]]: + refs: list[dict[str, str]] = [] + outputs = entry.get("outputs", {}) + for node in outputs.values(): + for img in node.get("images", []) or []: + refs.append( + { + "filename": img.get("filename", ""), + "subfolder": img.get("subfolder", ""), + "type": img.get("type", "output"), + } + ) + return refs + + +async def _download_image( + client: httpx.AsyncClient, endpoint: str, ref: dict[str, str] +) -> bytes: + qs = urllib.parse.urlencode( + { + "filename": ref["filename"], + "subfolder": ref["subfolder"], + "type": ref["type"], + } + ) + resp = await client.get(f"{endpoint}/view?{qs}") + resp.raise_for_status() + return resp.content + + +@router.post("/render", response_model=RenderResponse) +async def render(req: RenderRequest) -> RenderResponse: + """Render an image via Spark 1 (FLUX.2-dev on ComfyUI). + + If `workflow_json` is supplied (e.g. from BTE's IRecipeAssembler), it is + POSTed verbatim. Otherwise a stopgap FLUX.2 graph is built from `prompt`. + """ + workflow_json = req.workflow_json + if not workflow_json: + if not req.prompt: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Provide workflow_json or prompt.", + ) + workflow_json = build_stopgap_workflow( + req.prompt, req.width, req.height, req.guidance, req.steps, req.seed + ) + + endpoint = settings.spark1_flux_url.rstrip("/") + started = time.monotonic() + + try: + async with httpx.AsyncClient( + timeout=settings.vision_request_timeout_seconds + ) as client: + prompt_id = await _queue_prompt(client, endpoint, workflow_json) + images = await _poll_history(client, endpoint, prompt_id) + if not images: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"ComfyUI returned no output images for prompt {prompt_id}.", + ) + png_bytes = await _download_image(client, endpoint, images[0]) + except httpx.HTTPError as e: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Spark 1 (ComfyUI) at {endpoint} unreachable: {type(e).__name__}: {e}", + ) from e + + duration_ms = int((time.monotonic() - started) * 1000) + return RenderResponse( + prompt_id=prompt_id, + image_base64=base64.b64encode(png_bytes).decode("ascii"), + content_type="image/png", + duration_ms=duration_ms, ) diff --git a/src/svrnty_vision/routers/vlm.py b/src/svrnty_vision/routers/vlm.py index e32a658..b7b8f39 100644 --- a/src/svrnty_vision/routers/vlm.py +++ b/src/svrnty_vision/routers/vlm.py @@ -1,17 +1,193 @@ -"""VLM (vision-language model) analysis — stub until Phase 4b moves Qwen3-VL code.""" +"""VLM (vision-language model) analysis — proxies to Spark 2 (Qwen3-VL via vLLM). +Ported from BTE's OpenAiVlmClient.cs + VlmRubric.cs (Phase 4b). Cloud Anthropic +dialect intentionally dropped — svrnty-vision is sovereign-only. +""" + +from __future__ import annotations + +import base64 +import json +from decimal import Decimal +from typing import Any + +import httpx from fastapi import APIRouter, HTTPException, status +from pydantic import BaseModel, Field + +from svrnty_vision.settings import settings router = APIRouter(prefix="/vlm", tags=["vlm"]) -@router.post("/analyze") -async def analyze() -> None: - """Analyze an image with a vision-language model. +class AnalyzeRequest(BaseModel): + """At least one of `image_base64` or `image_url` must be supplied. - Phase 4a: stub. Phase 4b: proxies to Spark 2 (Qwen3-VL via vLLM). + `rubric_mode` is `polished` (premium aesthetic) or `ugc` (organic). + `brand_context` is the prompt anchor (brand description / extraction directives). + Set `rubric_mode = "raw"` to bypass the brand-scoring rubric and pass + `brand_context` through verbatim (used by extraction prompts, e.g. screenshot DNA). """ - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="vlm.analyze not implemented in Phase 4a — see BTE-REFACTOR-EXECUTION-PLAN Phase 4b", + + image_base64: str | None = None + image_url: str | None = None + content_type: str = "image/png" + brand_context: str = "" + rubric_mode: str = "polished" + model: str | None = None # override settings.spark2_vlm_model + max_tokens: int = 1024 + + +class AnalyzeResponse(BaseModel): + brand_fit_score: Decimal | None = Field( + default=None, + description="0.00–5.00, or null when rubric_mode='raw' (no scores parsed)", ) + visual_polish_score: Decimal | None = None + rubric_mode: str + justification: str = "" + model_id: str + raw_scores_json: str = Field( + description="The JSON object the VLM returned (or its full text if rubric_mode='raw')." + ) + + +def build_rubric_prompt(brand_context: str, rubric_mode: str) -> str: + """Ported from VlmRubric.BuildRubricPrompt. + + `rubric_mode='raw'` bypasses the brand-scoring rubric — caller's brand_context + becomes the full prompt (used for extraction-style calls). + """ + if rubric_mode == "raw": + return brand_context + + return ( + "You are a brand-asset reviewer. Brand context:\n" + "---\n" + f"{brand_context}\n" + "---\n" + f"Rubric mode: {rubric_mode} (polished = premium production aesthetic; " + "ugc = organic, casual, hand-shot feel).\n\n" + "Score the attached image on TWO axes, each 0.00 – 5.00:\n" + "- brand_fit: how well it embodies the brand voice/tokens above.\n" + "- visual_polish: technical execution under the chosen rubric mode.\n\n" + "Respond with ONE compact JSON object only — no prose, no fences:\n" + '{"brand_fit": , "visual_polish": , "justification": "<≤200 chars>"}' + ) + + +def parse_scores( + response_text: str, rubric_mode: str, model_id: str +) -> AnalyzeResponse: + """Ported from VlmRubric.ParseScores. + + For `rubric_mode='raw'` we return the response text untouched; scores stay None. + Callers (e.g. BTE ImageSetSourceReader / ExtractBrandSaga) parse the JSON + themselves. + """ + if rubric_mode == "raw": + return AnalyzeResponse( + brand_fit_score=None, + visual_polish_score=None, + rubric_mode=rubric_mode, + justification="", + model_id=model_id, + raw_scores_json=response_text, + ) + + trimmed = response_text.strip() + start = trimmed.find("{") + end = trimmed.rfind("}") + if start < 0 or end <= start: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"VLM response did not contain a JSON object: {response_text}", + ) + + raw_json = trimmed[start : end + 1] + try: + obj = json.loads(raw_json) + except json.JSONDecodeError as e: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"VLM JSON parse failed: {e}", + ) from e + + return AnalyzeResponse( + brand_fit_score=Decimal(str(obj["brand_fit"])), + visual_polish_score=Decimal(str(obj["visual_polish"])), + rubric_mode=rubric_mode, + justification=str(obj.get("justification", "")), + model_id=model_id, + raw_scores_json=raw_json, + ) + + +async def _resolve_data_uri(req: AnalyzeRequest) -> str: + if req.image_base64: + return f"data:{req.content_type};base64,{req.image_base64}" + if req.image_url: + # vLLM Qwen3-VL accepts http(s) URLs directly; pass through. + if req.image_url.startswith(("http://", "https://", "data:")): + return req.image_url + # Otherwise treat as local path → read + b64. + try: + with open(req.image_url, "rb") as f: + raw = f.read() + except OSError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"image_url unreadable as file: {e}", + ) from e + return f"data:{req.content_type};base64,{base64.b64encode(raw).decode('ascii')}" + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Provide image_base64 or image_url.", + ) + + +@router.post("/analyze", response_model=AnalyzeResponse) +async def analyze(req: AnalyzeRequest) -> AnalyzeResponse: + """Analyze an image with Qwen3-VL on Spark 2 (vLLM, OpenAI-compatible).""" + data_uri = await _resolve_data_uri(req) + rubric = build_rubric_prompt(req.brand_context, req.rubric_mode) + model = req.model or settings.spark2_vlm_model + + body: dict[str, Any] = { + "model": model, + "max_tokens": req.max_tokens, + "temperature": 0, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": rubric}, + {"type": "image_url", "image_url": {"url": data_uri}}, + ], + } + ], + } + + url = settings.spark2_vlm_url.rstrip("/") + "/v1/chat/completions" + try: + async with httpx.AsyncClient( + timeout=settings.vision_request_timeout_seconds + ) as client: + resp = await client.post(url, json=body) + resp.raise_for_status() + payload = resp.json() + except httpx.HTTPError as e: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Spark 2 (vLLM) at {url} unreachable: {type(e).__name__}: {e}", + ) from e + + try: + text = payload["choices"][0]["message"]["content"] or "" + except (KeyError, IndexError, TypeError) as e: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Spark 2 response shape unexpected: {e}", + ) from e + + return parse_scores(text, req.rubric_mode, model) diff --git a/tests/test_flux_workflow.py b/tests/test_flux_workflow.py new file mode 100644 index 0000000..f7edf17 --- /dev/null +++ b/tests/test_flux_workflow.py @@ -0,0 +1,68 @@ +"""Pytest port of BTE's StopgapFluxWorkflowTests + smoke for /flux/render.""" + +from __future__ import annotations + +import json +from unittest.mock import patch + +import httpx +from fastapi.testclient import TestClient + +from svrnty_vision.routers.flux import build_stopgap_workflow +from svrnty_vision.server import app + +client = TestClient(app) + + +def test_stopgap_workflow_includes_prompt_and_dimensions() -> None: + raw = build_stopgap_workflow("a sunlit plate of food", 1024, 1024) + assert "a sunlit plate of food" in raw + graph = json.loads(raw) + assert graph["5"]["inputs"]["width"] == 1024 + assert graph["5"]["inputs"]["height"] == 1024 + assert graph["10"]["inputs"]["vae_name"] == "flux2-vae.safetensors" + assert graph["11"]["inputs"]["type"] == "flux2" + assert graph["12"]["inputs"]["unet_name"] == "flux2_dev_fp8mixed.safetensors" + + +def test_stopgap_workflow_seeds_vary_per_call() -> None: + """ComfyUI dedupes identical workflows (execution_cached → empty outputs).""" + a = build_stopgap_workflow("x", 512, 512) + b = build_stopgap_workflow("x", 512, 512) + assert a != b + + +def test_stopgap_workflow_uses_explicit_seed_when_supplied() -> None: + a = build_stopgap_workflow("x", 512, 512, seed=42) + b = build_stopgap_workflow("x", 512, 512, seed=42) + assert a == b + + +def test_render_requires_workflow_or_prompt() -> None: + response = client.post("/flux/render", json={"width": 512, "height": 512}) + assert response.status_code == 400 + + +def test_render_returns_502_when_spark1_unreachable() -> None: + class _StubClient: + def __init__(self, *a, **kw): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + async def post(self, *a, **kw): + raise httpx.ConnectError("no comfy") + + async def get(self, *a, **kw): + raise httpx.ConnectError("no comfy") + + with patch("svrnty_vision.routers.flux.httpx.AsyncClient", _StubClient): + response = client.post( + "/flux/render", + json={"prompt": "test", "width": 512, "height": 512}, + ) + assert response.status_code == 502 diff --git a/tests/test_healthz.py b/tests/test_healthz.py index 14747dd..020af7a 100644 --- a/tests/test_healthz.py +++ b/tests/test_healthz.py @@ -15,21 +15,13 @@ def test_healthz_returns_200() -> None: assert "version" in body -def test_vlm_analyze_returns_501() -> None: - response = client.post("/vlm/analyze") - assert response.status_code == 501 - - -def test_flux_render_returns_501() -> None: - response = client.post("/flux/render") - assert response.status_code == 501 - - def test_palette_extract_returns_501() -> None: + # Still a 4a stub — Phase 4c moved only VLM + FLUX, palette/rembg deferred. response = client.post("/palette/extract") assert response.status_code == 501 def test_rembg_cutout_returns_501() -> None: + # Still a 4a stub — Phase 4c moved only VLM + FLUX, palette/rembg deferred. response = client.post("/rembg/cutout") assert response.status_code == 501 diff --git a/tests/test_vlm_parse.py b/tests/test_vlm_parse.py new file mode 100644 index 0000000..05bcbd0 --- /dev/null +++ b/tests/test_vlm_parse.py @@ -0,0 +1,147 @@ +"""Pytest port of BTE's FakeVlmEvaluationParseTests + VlmRubric parse coverage. + +These tests cover the pure-function side of the VLM router (rubric prompt + score +parsing). The HTTP call to Spark 2 is exercised separately via TestClient with a +mocked httpx transport. +""" + +from __future__ import annotations + +from decimal import Decimal +from unittest.mock import AsyncMock, patch + +import httpx +import pytest +from fastapi.testclient import TestClient + +from svrnty_vision.routers.vlm import build_rubric_prompt, parse_scores +from svrnty_vision.server import app + +client = TestClient(app) + + +def test_build_rubric_prompt_includes_brand_and_mode() -> None: + prompt = build_rubric_prompt("test-brand", "polished") + assert "test-brand" in prompt + assert "polished" in prompt + assert "brand_fit" in prompt + assert "visual_polish" in prompt + + +def test_build_rubric_prompt_raw_passes_through_verbatim() -> None: + assert build_rubric_prompt("anything goes", "raw") == "anything goes" + + +def test_parse_scores_extracts_canned_payload() -> None: + raw = '{"brand_fit":4.20,"visual_polish":3.80,"justification":"clean lighting"}' + result = parse_scores(raw, "polished", "qwen-test") + assert result.brand_fit_score == Decimal("4.20") + assert result.visual_polish_score == Decimal("3.80") + assert result.rubric_mode == "polished" + assert result.model_id == "qwen-test" + assert "brand_fit" in result.raw_scores_json + + +def test_parse_scores_tolerates_leading_or_trailing_prose() -> None: + raw = 'sure: {"brand_fit":2,"visual_polish":3}. done.' + result = parse_scores(raw, "ugc", "m") + assert result.brand_fit_score == Decimal("2") + assert result.visual_polish_score == Decimal("3") + + +def test_parse_scores_raw_mode_returns_text_untouched() -> None: + raw = '{"subject":"plate","mood":"warm"}' + result = parse_scores(raw, "raw", "m") + assert result.brand_fit_score is None + assert result.visual_polish_score is None + assert result.raw_scores_json == raw + + +def test_analyze_requires_image_input() -> None: + response = client.post( + "/vlm/analyze", + json={"brand_context": "x", "rubric_mode": "polished"}, + ) + assert response.status_code == 400 + + +def test_analyze_returns_502_when_spark2_unreachable() -> None: + """Smoke: with no Spark 2 (or a failing transport), gateway surfaces 502. + + Uses a mock async client that raises ConnectError on POST. + """ + + class _StubClient: + def __init__(self, *a, **kw): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + async def post(self, *a, **kw): + raise httpx.ConnectError("boom") + + with patch("svrnty_vision.routers.vlm.httpx.AsyncClient", _StubClient): + response = client.post( + "/vlm/analyze", + json={ + "image_base64": "ZmFrZQ==", + "brand_context": "test", + "rubric_mode": "polished", + }, + ) + assert response.status_code == 502 + + +def test_analyze_round_trip_with_mocked_spark2() -> None: + """Happy path: mock vLLM returns a well-formed score JSON; gateway parses it.""" + + canned_response = { + "choices": [ + { + "message": { + "content": '{"brand_fit": 4.5, "visual_polish": 4.0, "justification": "ok"}' + } + } + ] + } + + class _StubResponse: + status_code = 200 + + def raise_for_status(self): + return None + + def json(self): + return canned_response + + class _StubClient: + def __init__(self, *a, **kw): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + async def post(self, *a, **kw): + return _StubResponse() + + with patch("svrnty_vision.routers.vlm.httpx.AsyncClient", _StubClient): + response = client.post( + "/vlm/analyze", + json={ + "image_base64": "ZmFrZQ==", + "brand_context": "test-brand", + "rubric_mode": "polished", + }, + ) + assert response.status_code == 200, response.text + body = response.json() + assert Decimal(body["brand_fit_score"]) == Decimal("4.5") + assert Decimal(body["visual_polish_score"]) == Decimal("4.0") + assert body["rubric_mode"] == "polished"