1
1
from __future__ import annotations
2
2
3
3
import base64
4
- from typing import TYPE_CHECKING , Literal , Optional , cast
4
+ from typing import TYPE_CHECKING , Literal , Optional
5
5
6
6
import openai
7
- from attrs import define , field
7
+ from attrs import define , field , fields_dict
8
8
9
- from griptape .artifacts import ImageArtifact
10
9
from griptape .drivers .image_generation import BaseImageGenerationDriver
11
10
from griptape .utils .decorators import lazy_property
12
11
13
12
if TYPE_CHECKING :
14
13
from openai .types .images_response import ImagesResponse
15
14
15
+ from griptape .artifacts import ImageArtifact
16
+
16
17
17
18
@define
18
19
class OpenAiImageGenerationDriver (BaseImageGenerationDriver ):
@@ -32,49 +33,106 @@ class OpenAiImageGenerationDriver(BaseImageGenerationDriver):
32
33
dall-e-3: [1024x1024, 1024x1792, 1792x1024]
33
34
response_format: The response format. Currently only supports 'b64_json' which will return
34
35
a base64 encoded image in a JSON object.
36
+ background: Optional and only supported for gpt-image-1. Can be either 'transparent', 'opaque', or 'auto'.
37
+ moderation: Optional and only supported for gpt-image-1. Can be either 'low' or 'auto'.
38
+ output_compression: Optional and only supported for gpt-image-1. Can be an integer between 0 and 100.
39
+ output_format: Optional and only supported for gpt-image-1. Can be either 'png' or 'jpeg'.
35
40
"""
36
41
37
42
# Connection settings; api_key, base_url and organization are passed to the OpenAI client.
api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})

# Generation options. "model_allowlist" / "model_denylist" metadata is read by
# _build_model_params to decide whether an option is forwarded for the current model.
style: Optional[Literal["vivid", "natural"]] = field(
    default=None, kw_only=True, metadata={"serializable": True, "model_allowlist": ["dall-e-3"]}
)
quality: Optional[Literal["standard", "hd", "low", "medium", "high", "auto"]] = field(
    default=None,
    kw_only=True,
    metadata={"serializable": True},
)
# The literal lists every size accepted by validate_image_size for any model
# (dall-e-2, dall-e-3 and gpt-image); the validator narrows it per model at runtime.
image_size: Optional[
    Literal[
        "256x256",
        "512x512",
        "1024x1024",
        "1024x1792",
        "1792x1024",
        "1536x1024",
        "1024x1536",
        "auto",
    ]
] = field(
    default=None,
    kw_only=True,
    metadata={"serializable": True},
)
response_format: Literal["b64_json"] = field(
    default="b64_json",
    kw_only=True,
    metadata={"serializable": True, "model_denylist": ["gpt-image-1"]},
)
# The four options below are only forwarded for gpt-image-1 (see their allowlists).
background: Optional[Literal["transparent", "opaque", "auto"]] = field(
    default=None,
    kw_only=True,
    metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
)
moderation: Optional[Literal["low", "auto"]] = field(
    default=None,
    kw_only=True,
    metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
)
output_compression: Optional[int] = field(
    default=None,
    kw_only=True,
    metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
)
output_format: Optional[Literal["png", "jpeg"]] = field(
    default=None,
    kw_only=True,
    metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
)
# Backing attribute for the `client` property; accepted as `client=` in the constructor.
_client: Optional[openai.OpenAI] = field(
    default=None, kw_only=True, alias="client", metadata={"serializable": False}
)
55
88
89
@image_size.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_image_size(self, attribute: str, value: str | None) -> None:
    """Check that the configured image size is one the current model accepts.

    Accepted sizes per model:
        - gpt-image models: `1024x1024`, `1536x1024` (landscape), `1024x1536` (portrait), `auto`
        - `dall-e-2`: `256x256`, `512x512`, `1024x1024`
        - `dall-e-3`: `1024x1024`, `1792x1024`, `1024x1792`

    A value of None is always accepted.

    Raises:
        NotImplementedError: If a size is set for a model with no known size list.
        ValueError: If the size is not in the model's list.
    """
    if value is None:
        return

    model = self.model
    if model.startswith("gpt-image"):
        permitted = ("1024x1024", "1536x1024", "1024x1536", "auto")
    else:
        sizes_by_model = {
            "dall-e-2": ("256x256", "512x512", "1024x1024"),
            "dall-e-3": ("1024x1024", "1792x1024", "1024x1792"),
        }
        if model not in sizes_by_model:
            raise NotImplementedError(f"Image size validation not implemented for model {self.model}")
        permitted = sizes_by_model[model]

    if value in permitted:
        return

    raise ValueError(f"Image size, {value}, must be one of the following: {permitted}")
112
+
56
113
@lazy_property()
def client(self) -> openai.OpenAI:
    """Build an OpenAI client from this driver's api_key, base_url and organization."""
    connection_settings = {
        "api_key": self.api_key,
        "base_url": self.base_url,
        "organization": self.organization,
    }
    return openai.OpenAI(**connection_settings)
59
116
60
117
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Generate a single image from the given prompts.

    Args:
        prompts: Prompt fragments; they are joined with ", " into one prompt.
        negative_prompts: Accepted for interface compatibility; not used here.

    Returns:
        The artifact produced by `_parse_image_response` for the API response.
    """
    prompt = ", ".join(prompts)
    generate = self.client.images.generate

    # API parameter name -> driver field name. _build_model_params decides
    # which of these are actually sent for the current model.
    param_to_field = {
        "size": "image_size",
        "quality": "quality",
        "style": "style",
        "response_format": "response_format",
        "background": "background",
        "moderation": "moderation",
        "output_compression": "output_compression",
        "output_format": "output_format",
    }
    model_params = self._build_model_params(param_to_field)

    response = generate(model=self.model, prompt=prompt, n=1, **model_params)

    return self._parse_image_response(response, prompt)
@@ -85,13 +143,18 @@ def try_image_variation(
85
143
image : ImageArtifact ,
86
144
negative_prompts : Optional [list [str ]] = None ,
87
145
) -> ImageArtifact :
88
- image_size = self . _dall_e_2_filter_image_size ( " variation" )
146
+ """Creates a variation of an image.
89
147
148
+ Only supported by dall-e-2. Requires image size to be one of the following:
149
+ [256x256, 512x512, 1024x1024]
150
+ """
151
+ if self .model != "dall-e-2" :
152
+ raise NotImplementedError ("Image variation only supports dall-e-2" )
90
153
response = self .client .images .create_variation (
91
154
image = image .value ,
92
155
n = 1 ,
93
156
response_format = self .response_format ,
94
- size = image_size ,
157
+ size = self . image_size , # pyright: ignore[reportArgumentType]
95
158
)
96
159
97
160
return self ._parse_image_response (response , "" )
@@ -103,15 +166,17 @@ def try_image_inpainting(
103
166
mask : ImageArtifact ,
104
167
negative_prompts : Optional [list [str ]] = None ,
105
168
) -> ImageArtifact :
106
- image_size = self ._dall_e_2_filter_image_size ("inpainting" )
107
-
108
169
prompt = ", " .join (prompts )
109
170
response = self .client .images .edit (
110
171
prompt = prompt ,
111
172
image = image .value ,
112
173
mask = mask .value ,
113
- response_format = self .response_format ,
114
- size = image_size ,
174
+ ** self ._build_model_params (
175
+ {
176
+ "size" : "image_size" ,
177
+ "response_format" : "response_format" ,
178
+ }
179
+ ),
115
180
)
116
181
117
182
return self ._parse_image_response (response , prompt )
@@ -125,29 +190,45 @@ def try_image_outpainting(
125
190
) -> ImageArtifact :
126
191
raise NotImplementedError (f"{ self .__class__ .__name__ } does not support outpainting" )
127
192
128
- def _image_size_to_ints (self , image_size : str ) -> list [int ]:
129
- return [int (x ) for x in image_size .split ("x" )]
130
-
131
- def _dall_e_2_filter_image_size (self , method : str ) -> Literal ["256x256" , "512x512" , "1024x1024" ]:
132
- if self .model != "dall-e-2" :
133
- raise NotImplementedError (f"{ method } only supports dall-e-2" )
134
-
135
- if self .image_size not in {"256x256" , "512x512" , "1024x1024" }:
136
- raise ValueError (f"support image sizes for { method } are 256x256, 512x512, and 1024x1024" )
137
-
138
- return cast ("Literal['256x256', '512x512', '1024x1024']" , self .image_size )
139
-
140
193
def _parse_image_response(self, response: ImagesResponse, prompt: str) -> ImageArtifact:
    """Convert an OpenAI images response into an ImageArtifact.

    Args:
        response: Response whose first `data` entry is expected to carry the
            image as a base64 string in `b64_json`.
        prompt: Prompt to record in the artifact's metadata.

    Returns:
        The artifact parsed from the decoded image bytes, with `prompt` and
        `model` stored in its `meta` dictionary.

    Raises:
        Exception: If the response has no data, an empty data list, a missing
            first entry, or no base64 payload.
    """
    # `not response.data` covers both None and an empty list; indexing an empty
    # list would otherwise raise IndexError instead of the error below.
    if not response.data or response.data[0] is None or response.data[0].b64_json is None:
        raise Exception("Failed to generate image")

    image_data = base64.b64decode(response.data[0].b64_json)

    # Function-scope import (presumably to avoid an import cycle - confirm);
    # it is only needed once a payload is known to be present.
    from griptape.loaders.image_loader import ImageLoader

    image_artifact = ImageLoader().parse(image_data)

    image_artifact.meta["prompt"] = prompt
    image_artifact.meta["model"] = self.model

    return image_artifact
207
+
208
def _build_model_params(self, values: dict) -> dict:
    """Build request parameters, honoring field metadata and skipping None values.

    Args:
        values: A dictionary mapping API parameter names to field names on this driver.

    Returns:
        A dictionary of parameter name to field value. A field is included only if
        all conditions are met:
            - The field value is not None
            - The field has no model_allowlist, or the model matches an allowlist entry
            - The field has no model_denylist, or the model matches no denylist entry

    A model matches a list entry when its name starts with that entry, so variants
    within a model family (e.g. names beginning with "gpt-image-1") are treated like
    the listed model. This mirrors the prefix check in the image size validator.

    Raises:
        KeyError: If a field name in `values` is not an attrs field of this class.
    """

    def model_matches(names: list) -> bool:
        # str.startswith accepts a tuple of prefixes; an empty tuple never matches.
        return self.model.startswith(tuple(names))

    params = {}
    attrs_fields = fields_dict(self.__class__)

    for param_name, field_name in values.items():
        metadata = attrs_fields[field_name].metadata
        model_allowlist = metadata.get("model_allowlist")
        model_denylist = metadata.get("model_denylist")

        if model_allowlist is not None and not model_matches(model_allowlist):
            continue
        if model_denylist is not None and model_matches(model_denylist):
            continue

        field_value = getattr(self, field_name, None)
        if field_value is not None:
            params[param_name] = field_value

    return params
0 commit comments