"""Unit tests for POST /palette/extract.""" from __future__ import annotations import base64 import io import pytest from PIL import Image from svrnty_vision.routers.palette import PaletteRequest, PaletteResponse, extract # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_png_b64(color: tuple[int, int, int], size: int = 50) -> str: img = Image.new("RGB", (size, size), color=color) buf = io.BytesIO() img.save(buf, format="PNG") return base64.b64encode(buf.getvalue()).decode("ascii") # --------------------------------------------------------------------------- # Unit tests (pure function / TestClient — no network) # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_solid_red_dominant_is_red() -> None: req = PaletteRequest(image_base64=_make_png_b64((255, 0, 0)), color_count=3) resp: PaletteResponse = await extract(req) r, g, b = resp.dominant assert r > 200, "dominant R channel should be high for solid red" assert g < 80 assert b < 80 @pytest.mark.asyncio async def test_palette_color_count_respected() -> None: req = PaletteRequest(image_base64=_make_png_b64((0, 128, 255)), color_count=4) resp = await extract(req) assert resp.color_count <= 4 assert len(resp.palette) == resp.color_count @pytest.mark.asyncio async def test_palette_each_entry_is_rgb_triple() -> None: req = PaletteRequest(image_base64=_make_png_b64((100, 200, 50)), color_count=6) resp = await extract(req) for entry in resp.palette: assert len(entry) == 3 assert all(0 <= c <= 255 for c in entry) @pytest.mark.asyncio async def test_palette_missing_image_raises_400() -> None: from fastapi import HTTPException with pytest.raises(HTTPException) as exc_info: await extract(PaletteRequest()) assert exc_info.value.status_code == 400 @pytest.mark.asyncio async def test_palette_bad_base64_raises_400() -> None: from fastapi import HTTPException with pytest.raises(HTTPException) as exc_info: await extract(PaletteRequest(image_base64="!!!notbase64!!!")) assert exc_info.value.status_code == 400 def test_palette_via_test_client(client, red_png_b64) -> None: resp = client.post("/palette/extract", json={"image_base64": red_png_b64, "color_count": 5}) assert resp.status_code == 200 body = resp.json() assert "dominant" in body assert len(body["dominant"]) == 3 assert body["color_count"] <= 5 assert len(body["palette"]) == body["color_count"]