Skip to content

Commit c8cf6b2

Browse files
authored
Automatic Serialization/deserialization of images (#361)
* image input/output * format
1 parent 0c19861 commit c8cf6b2

File tree

3 files changed

+110
-0
lines changed

3 files changed

+110
-0
lines changed

src/litserve/schema/__init__.py

Whitespace-only changes.

src/litserve/schema/image.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import base64
2+
from io import BytesIO
3+
from typing import TYPE_CHECKING, Any
4+
5+
from pydantic import BaseModel, field_serializer, field_validator
6+
7+
if TYPE_CHECKING:
8+
from PIL import Image
9+
10+
11+
class ImageInput(BaseModel):
12+
image_data: str
13+
14+
@field_validator("image_data")
15+
def validate_base64(cls, value: str) -> str: # noqa
16+
"""Ensure the string is a valid Base64."""
17+
try:
18+
base64.b64decode(value)
19+
except base64.binascii.Error:
20+
raise ValueError("Invalid Base64 string.")
21+
return value
22+
23+
def get_image(self) -> "Image.Image":
24+
"""Decode the Base64 string and return a PIL Image object."""
25+
try:
26+
from PIL import Image, UnidentifiedImageError
27+
except ImportError:
28+
raise ImportError("Pillow is required to use the ImageInput schema. Install it with `pip install Pillow`.")
29+
try:
30+
decoded_data = base64.b64decode(self.image_data)
31+
return Image.open(BytesIO(decoded_data))
32+
except base64.binascii.Error as e:
33+
raise ValueError(f"Error decoding Base64 string: {e}")
34+
except UnidentifiedImageError as e:
35+
raise ValueError(f"Error loading image from decoded data: {e}")
36+
37+
38+
class ImageOutput(BaseModel):
39+
image: Any
40+
41+
@field_serializer("image")
42+
def serialize_image(self, image: Any, _info):
43+
"""
44+
Serialize a PIL Image into a base64 string.
45+
Args:
46+
image (Any): The image object to serialize.
47+
_info: Metadata passed during serialization (not used here).
48+
49+
Returns:
50+
str: Base64-encoded image string.
51+
"""
52+
try:
53+
from PIL import Image
54+
except ImportError:
55+
raise ImportError("Pillow is required to use the ImageOutput schema. Install it with `pip install Pillow`.")
56+
57+
if not isinstance(image, Image.Image):
58+
raise TypeError(f"Expected a PIL Image, got {type(image)}")
59+
60+
# Save the image to a BytesIO buffer
61+
buffer = BytesIO()
62+
image.save(buffer, format="PNG") # Default format is PNG
63+
buffer.seek(0)
64+
65+
# Encode the buffer content to base64
66+
base64_bytes = base64.b64encode(buffer.read())
67+
68+
# Decode to string for JSON serialization
69+
return base64_bytes.decode("utf-8")

tests/test_schema.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import base64
2+
import io
3+
import os
4+
5+
import numpy as np
6+
from fastapi.testclient import TestClient
7+
from PIL import Image
8+
9+
import litserve as ls
10+
from litserve.schema.image import ImageInput, ImageOutput
11+
from litserve.utils import wrap_litserve_start
12+
13+
14+
class ImageAPI(ls.LitAPI):
15+
def setup(self, device):
16+
self.model = lambda x: np.array(x) * 2
17+
18+
def decode_request(self, request: ImageInput):
19+
return request.get_image()
20+
21+
def predict(self, x):
22+
return self.model(x)
23+
24+
def encode_response(self, numpy_image) -> ImageOutput:
25+
output = Image.fromarray(np.uint8(numpy_image)).convert("RGB")
26+
return ImageOutput(image=output)
27+
28+
29+
def test_image_input_output(tmpdir):
30+
path = os.path.join(tmpdir, "test.png")
31+
server = ls.LitServer(ImageAPI(), accelerator="cpu", devices=1, workers_per_device=1)
32+
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
33+
Image.new("RGB", (32, 32)).save(path)
34+
with open(path, "rb") as image_file:
35+
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
36+
response = client.post("/predict", json={"image_data": encoded_string})
37+
38+
assert response.status_code == 200, f"Unexpected status code: {response.status_code}"
39+
image_data = response.json()["image"]
40+
image = Image.open(io.BytesIO(base64.b64decode(image_data)))
41+
assert image.size == (32, 32), f"Unexpected image size: {image.size}"

0 commit comments

Comments
 (0)