Skip to content

Commit 2a6f2a9

Browse files
authored
feat: add support for gpt-image-1 (#1921)
1 parent 0d049da commit 2a6f2a9

File tree

8 files changed

+208
-85
lines changed

8 files changed

+208
-85
lines changed

docs/griptape-framework/drivers/image-generation-drivers.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ This Driver supports negative prompts. When provided, the image generation reque
113113

114114
The [OpenAI Image Generation Driver](../../reference/griptape/drivers/image_generation/openai_image_generation_driver.md) provides access to OpenAI image generation models. Like other OpenAI Drivers, the image generation Driver will implicitly load an API key in the `OPENAI_API_KEY` environment variable if one is not explicitly provided.
115115

116-
This Driver supports image generation configurations like style presets, image quality preference, and image size. For details on supported configuration values, see the [OpenAI documentation](https://platform.openai.com/docs/guides/images/introduction).
116+
This Driver supports image generation configurations like style presets, image quality preference, and image size. For details on supported configuration values, see the [OpenAI documentation](https://platform.openai.com/docs/guides/image-generation).
117117

118118
=== "Code"
119119

docs/griptape-framework/drivers/src/image_generation_drivers_1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from griptape.tools import PromptImageGenerationTool
44

55
driver = OpenAiImageGenerationDriver(
6-
model="dall-e-2",
6+
model="gpt-image-1",
77
)
88

99
agent = Agent(
Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
11
from griptape.drivers.image_generation.openai import OpenAiImageGenerationDriver
22
from griptape.structures import Agent
3-
from griptape.tools import PromptImageGenerationTool
3+
from griptape.tools import FileManagerTool, PromptImageGenerationTool
44

5-
driver = OpenAiImageGenerationDriver(
6-
model="dall-e-2",
7-
image_size="512x512",
8-
)
5+
driver = OpenAiImageGenerationDriver(model="gpt-image-1")
96

7+
agent = Agent(tools=[PromptImageGenerationTool(image_generation_driver=driver, off_prompt=True), FileManagerTool()])
108

11-
agent = Agent(
12-
tools=[
13-
PromptImageGenerationTool(image_generation_driver=driver),
14-
]
15-
)
16-
17-
agent.run("Generate a watercolor painting of a dog riding a skateboard")
9+
agent.run("Generate a watercolor painting of a dog riding a skateboard and save it to dog.png")
Lines changed: 128 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
from __future__ import annotations
22

33
import base64
4-
from typing import TYPE_CHECKING, Literal, Optional, cast
4+
from typing import TYPE_CHECKING, Literal, Optional
55

66
import openai
7-
from attrs import define, field
7+
from attrs import define, field, fields_dict
88

9-
from griptape.artifacts import ImageArtifact
109
from griptape.drivers.image_generation import BaseImageGenerationDriver
1110
from griptape.utils.decorators import lazy_property
1211

1312
if TYPE_CHECKING:
1413
from openai.types.images_response import ImagesResponse
1514

15+
from griptape.artifacts import ImageArtifact
16+
1617

1718
@define
1819
class OpenAiImageGenerationDriver(BaseImageGenerationDriver):
@@ -32,49 +33,106 @@ class OpenAiImageGenerationDriver(BaseImageGenerationDriver):
3233
dall-e-3: [1024x1024, 1024x1792, 1792x1024]
3334
response_format: The response format. Currently only supports 'b64_json' which will return
3435
a base64 encoded image in a JSON object.
36+
background: Optional and only supported for gpt-image-1. Can be either 'transparent', 'opaque', or 'auto'.
37+
moderation: Optional and only supported for gpt-image-1. Can be either 'low' or 'auto'.
38+
output_compression: Optional and only supported for gpt-image-1. Can be an integer between 0 and 100.
39+
output_format: Optional and only supported for gpt-image-1. Can be either 'png' or 'jpeg'.
3540
"""
3641

3742
api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
3843
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
3944
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
4045
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
4146
organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
42-
style: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
43-
quality: Literal["standard", "hd"] = field(
44-
default="standard",
47+
style: Optional[Literal["vivid", "natural"]] = field(
48+
default=None, kw_only=True, metadata={"serializable": True, "model_allowlist": ["dall-e-3"]}
49+
)
50+
quality: Optional[Literal["standard", "hd", "low", "medium", "high", "auto"]] = field(
51+
default=None,
52+
kw_only=True,
53+
metadata={"serializable": True},
54+
)
55+
image_size: Optional[Literal["256x256", "512x512", "1024x1024", "1024x1792", "1792x1024"]] = field(
56+
default=None,
4557
kw_only=True,
4658
metadata={"serializable": True},
4759
)
48-
image_size: Literal["256x256", "512x512", "1024x1024", "1024x1792", "1792x1024"] = field(
49-
default="1024x1024", kw_only=True, metadata={"serializable": True}
60+
response_format: Literal["b64_json"] = field(
61+
default="b64_json",
62+
kw_only=True,
63+
metadata={"serializable": True, "model_denylist": ["gpt-image-1"]},
64+
)
65+
background: Optional[Literal["transparent", "opaque", "auto"]] = field(
66+
default=None,
67+
kw_only=True,
68+
metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
69+
)
70+
moderation: Optional[Literal["low", "auto"]] = field(
71+
default=None,
72+
kw_only=True,
73+
metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
74+
)
75+
output_compression: Optional[int] = field(
76+
default=None,
77+
kw_only=True,
78+
metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
79+
)
80+
output_format: Optional[Literal["png", "jpeg"]] = field(
81+
default=None,
82+
kw_only=True,
83+
metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
5084
)
51-
response_format: Literal["b64_json"] = field(default="b64_json", kw_only=True, metadata={"serializable": True})
5285
_client: Optional[openai.OpenAI] = field(
5386
default=None, kw_only=True, alias="client", metadata={"serializable": False}
5487
)
5588

89+
@image_size.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
90+
def validate_image_size(self, attribute: str, value: str | None) -> None:
91+
"""Validates the image size based on the model.
92+
93+
Must be one of `1024x1024`, `1536x1024` (landscape), `1024x1536` (portrait), or `auto` (default value) for
94+
`gpt-image-1`, one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`, and
95+
one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3`.
96+
97+
"""
98+
if value is None:
99+
return
100+
101+
if self.model.startswith("gpt-image"):
102+
allowed_sizes = ("1024x1024", "1536x1024", "1024x1536", "auto")
103+
elif self.model == "dall-e-2":
104+
allowed_sizes = ("256x256", "512x512", "1024x1024")
105+
elif self.model == "dall-e-3":
106+
allowed_sizes = ("1024x1024", "1792x1024", "1024x1792")
107+
else:
108+
raise NotImplementedError(f"Image size validation not implemented for model {self.model}")
109+
110+
if value is not None and value not in allowed_sizes:
111+
raise ValueError(f"Image size, {value}, must be one of the following: {allowed_sizes}")
112+
56113
@lazy_property()
57114
def client(self) -> openai.OpenAI:
58115
return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
59116

60117
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
61118
prompt = ", ".join(prompts)
62119

63-
additional_params = {}
64-
65-
if self.style:
66-
additional_params["style"] = self.style
67-
68-
if self.quality:
69-
additional_params["quality"] = self.quality
70-
71120
response = self.client.images.generate(
72121
model=self.model,
73122
prompt=prompt,
74-
size=self.image_size,
75-
response_format=self.response_format,
76123
n=1,
77-
**additional_params,
124+
**self._build_model_params(
125+
{
126+
"size": "image_size",
127+
"quality": "quality",
128+
"style": "style",
129+
"response_format": "response_format",
130+
"background": "background",
131+
"moderation": "moderation",
132+
"output_compression": "output_compression",
133+
"output_format": "output_format",
134+
}
135+
),
78136
)
79137

80138
return self._parse_image_response(response, prompt)
@@ -85,13 +143,18 @@ def try_image_variation(
85143
image: ImageArtifact,
86144
negative_prompts: Optional[list[str]] = None,
87145
) -> ImageArtifact:
88-
image_size = self._dall_e_2_filter_image_size("variation")
146+
"""Creates a variation of an image.
89147
148+
Only supported for dall-e-2. Requires image size to be one of the following:
149+
[256x256, 512x512, 1024x1024]
150+
"""
151+
if self.model != "dall-e-2":
152+
raise NotImplementedError("Image variation only supports dall-e-2")
90153
response = self.client.images.create_variation(
91154
image=image.value,
92155
n=1,
93156
response_format=self.response_format,
94-
size=image_size,
157+
size=self.image_size, # pyright: ignore[reportArgumentType]
95158
)
96159

97160
return self._parse_image_response(response, "")
@@ -103,15 +166,17 @@ def try_image_inpainting(
103166
mask: ImageArtifact,
104167
negative_prompts: Optional[list[str]] = None,
105168
) -> ImageArtifact:
106-
image_size = self._dall_e_2_filter_image_size("inpainting")
107-
108169
prompt = ", ".join(prompts)
109170
response = self.client.images.edit(
110171
prompt=prompt,
111172
image=image.value,
112173
mask=mask.value,
113-
response_format=self.response_format,
114-
size=image_size,
174+
**self._build_model_params(
175+
{
176+
"size": "image_size",
177+
"response_format": "response_format",
178+
}
179+
),
115180
)
116181

117182
return self._parse_image_response(response, prompt)
@@ -125,29 +190,45 @@ def try_image_outpainting(
125190
) -> ImageArtifact:
126191
raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting")
127192

128-
def _image_size_to_ints(self, image_size: str) -> list[int]:
129-
return [int(x) for x in image_size.split("x")]
130-
131-
def _dall_e_2_filter_image_size(self, method: str) -> Literal["256x256", "512x512", "1024x1024"]:
132-
if self.model != "dall-e-2":
133-
raise NotImplementedError(f"{method} only supports dall-e-2")
134-
135-
if self.image_size not in {"256x256", "512x512", "1024x1024"}:
136-
raise ValueError(f"support image sizes for {method} are 256x256, 512x512, and 1024x1024")
137-
138-
return cast("Literal['256x256', '512x512', '1024x1024']", self.image_size)
139-
140193
def _parse_image_response(self, response: ImagesResponse, prompt: str) -> ImageArtifact:
194+
from griptape.loaders.image_loader import ImageLoader
195+
141196
if response.data is None or response.data[0] is None or response.data[0].b64_json is None:
142197
raise Exception("Failed to generate image")
143198

144199
image_data = base64.b64decode(response.data[0].b64_json)
145-
image_dimensions = self._image_size_to_ints(self.image_size)
146-
147-
return ImageArtifact(
148-
value=image_data,
149-
format="png",
150-
width=image_dimensions[0],
151-
height=image_dimensions[1],
152-
meta={"model": self.model, "prompt": prompt},
153-
)
200+
201+
image_artifact = ImageLoader().parse(image_data)
202+
203+
image_artifact.meta["prompt"] = prompt
204+
image_artifact.meta["model"] = self.model
205+
206+
return image_artifact
207+
208+
def _build_model_params(self, values: dict) -> dict:
209+
"""Builds parameters while considering field metadata and None values.
210+
211+
Args:
212+
values: A dictionary mapping parameter names to field names.
213+
214+
Field will be added to the params dictionary if all conditions are met:
215+
- The field value is not None
216+
- The model_allowlist is None or the model is in the allowlist
217+
- The model_denylist is None or the model is not in the denylist
218+
"""
219+
params = {}
220+
221+
fields = fields_dict(self.__class__)
222+
for param_name, field_name in values.items():
223+
metadata = fields[field_name].metadata
224+
model_allowlist = metadata.get("model_allowlist")
225+
model_denylist = metadata.get("model_denylist")
226+
227+
field_value = getattr(self, field_name, None)
228+
229+
allowlist_condition = model_allowlist is None or self.model in model_allowlist
230+
denylist_condition = model_denylist is None or self.model not in model_denylist
231+
232+
if field_value is not None and allowlist_condition and denylist_condition:
233+
params[param_name] = field_value
234+
return params

tests/unit/configs/drivers/test_azure_openai_drivers_config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,16 @@ def test_to_dict(self, config):
5858
"image_generation_driver": {
5959
"api_version": "2024-02-01",
6060
"base_url": None,
61+
"background": None,
6162
"image_size": "512x512",
6263
"model": "dall-e-2",
64+
"moderation": None,
6365
"azure_deployment": "dall-e-2",
6466
"azure_endpoint": "http://localhost:8080",
6567
"organization": None,
66-
"quality": "standard",
68+
"output_compression": None,
69+
"output_format": None,
70+
"quality": None,
6771
"response_format": "b64_json",
6872
"style": None,
6973
"type": "AzureOpenAiImageGenerationDriver",

tests/unit/configs/drivers/test_openai_driver_config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,14 @@ def test_to_dict(self, config):
4747
"image_generation_driver": {
4848
"api_version": None,
4949
"base_url": None,
50+
"background": None,
5051
"image_size": "512x512",
5152
"model": "dall-e-2",
5253
"organization": None,
53-
"quality": "standard",
54+
"output_compression": None,
55+
"output_format": None,
56+
"moderation": None,
57+
"quality": None,
5458
"response_format": "b64_json",
5559
"style": None,
5660
"type": "OpenAiImageGenerationDriver",

tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,37 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Optional
14
from unittest.mock import Mock
25

6+
import PIL.Image
37
import pytest
48

59
from griptape.drivers.image_generation.openai import AzureOpenAiImageGenerationDriver
610

11+
if TYPE_CHECKING:
12+
import io
13+
14+
15+
@pytest.fixture(autouse=True)
16+
def _patch_pillow_open(mocker):
17+
"""Stub out PIL.Image.open so no real decoding is attempted."""
18+
19+
class _FakeImage:
20+
def __init__(self) -> None:
21+
self.format: str = "PNG"
22+
self.width: int = 512
23+
self.height: int = 512
24+
25+
def save(self, fp: io.BytesIO, *, _: Optional[str] = None) -> None:
26+
fp.write(b"image data")
27+
28+
mocker.patch.object(
29+
PIL.Image,
30+
"open",
31+
side_effect=lambda *_, **__: _FakeImage(),
32+
autospec=True,
33+
)
34+
735

836
class TestAzureOpenAiImageGenerationDriver:
937
@pytest.fixture()
@@ -13,27 +41,18 @@ def driver(self):
1341
client=Mock(),
1442
azure_endpoint="https://dalle.example.com",
1543
azure_deployment="dalle-deployment",
16-
image_size="512x512",
44+
image_size="1024x1024",
1745
)
1846

1947
def test_init(self, driver):
2048
assert driver
2149
assert (
2250
AzureOpenAiImageGenerationDriver(
23-
model="dall-e-3", client=Mock(), azure_endpoint="https://dalle.example.com", image_size="512x512"
51+
model="dall-e-3", client=Mock(), azure_endpoint="https://dalle.example.com", image_size="1024x1024"
2452
).azure_deployment
2553
== "dall-e-3"
2654
)
2755

28-
def test_init_requires_endpoint(self):
29-
with pytest.raises(TypeError):
30-
AzureOpenAiImageGenerationDriver(
31-
model="dall-e-3",
32-
client=Mock(),
33-
azure_deployment="dalle-deployment",
34-
image_size="512x512",
35-
) # pyright: ignore[reportCallIssues]
36-
3756
def test_try_text_to_image(self, driver):
3857
driver.client.images.generate.return_value = Mock(data=[Mock(b64_json=b"aW1hZ2UgZGF0YQ==")])
3958

0 commit comments

Comments
 (0)