9
9
10
10
from smithy_core import URI
11
11
from smithy_core .codecs import Codec
12
+ from smithy_core .exceptions import SerializationError
12
13
from smithy_core .schemas import Schema
13
14
from smithy_core .serializers import (
14
15
InterceptingSerializer ,
24
25
HTTPQueryTrait ,
25
26
HTTPTrait ,
26
27
MediaTypeTrait ,
28
+ RequiresLengthTrait ,
27
29
TimestampFormatTrait ,
28
30
)
29
31
from smithy_core .types import PathPattern , TimestampFormat
30
32
from smithy_core .utils import serialize_float
31
33
32
- from . import tuples_to_fields
34
+ from . import Field , tuples_to_fields
33
35
from .aio import HTTPRequest as _HTTPRequest
34
36
from .aio import HTTPResponse as _HTTPResponse
35
37
from .aio .interfaces import HTTPRequest , HTTPResponse
43
45
__all__ = ["HTTPRequestSerializer" , "HTTPResponseSerializer" ]
44
46
45
47
48
+ # TODO: refactor this to share code with response serializer
46
49
class HTTPRequestSerializer (SpecificShapeSerializer ):
47
50
"""Binds a serializable shape to an HTTP request.
48
51
@@ -82,8 +85,12 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
82
85
host_prefix = self ._endpoint_trait .host_prefix
83
86
84
87
content_type = self ._payload_codec .media_type
88
+ content_length : int | None = None
89
+ content_length_required = False
90
+
85
91
binding_matcher = RequestBindingMatcher (schema )
86
92
if (payload_member := binding_matcher .payload_member ) is not None :
93
+ content_length_required = RequiresLengthTrait in payload_member
87
94
if payload_member .shape_type in (
88
95
ShapeType .BLOB ,
89
96
ShapeType .STRING ,
@@ -105,6 +112,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
105
112
)
106
113
yield binding_serializer
107
114
payload = payload_serializer .payload
115
+ try :
116
+ content_length = len (payload )
117
+ except TypeError :
118
+ pass
108
119
else :
109
120
if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
110
121
content_type = media_type .value
@@ -117,6 +128,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
117
128
binding_matcher ,
118
129
)
119
130
yield binding_serializer
131
+ content_length = payload .tell ()
120
132
else :
121
133
payload = BytesIO ()
122
134
payload_serializer = self ._payload_codec .create_serializer (payload )
@@ -131,8 +143,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
131
143
binding_matcher ,
132
144
)
133
145
yield binding_serializer
146
+ content_length = payload .tell ()
134
147
else :
135
148
content_type = None
149
+ content_length = 0
136
150
binding_serializer = HTTPRequestBindingSerializer (
137
151
payload_serializer ,
138
152
self ._http_trait .path ,
@@ -141,15 +155,24 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
141
155
)
142
156
yield binding_serializer
143
157
144
- if (
145
- seek := getattr (payload , "seek" , None )
146
- ) is not None and not iscoroutinefunction (seek ):
147
- seek (0 )
148
-
158
+ self ._seek (payload , 0 )
149
159
headers = binding_serializer .header_serializer .headers
150
160
if content_type is not None :
151
161
headers .append (("content-type" , content_type ))
152
162
163
+ if content_length is not None :
164
+ headers .append (("content-length" , str (content_length )))
165
+
166
+ fields = tuples_to_fields (headers )
167
+ if content_length_required and "content-length" not in fields :
168
+ content_length = self ._compute_content_length (payload )
169
+ if content_length is None :
170
+ raise SerializationError (
171
+ "This operation requires the the content length of the input "
172
+ "stream, but it was not provided and was unable to be computed."
173
+ )
174
+ fields .set_field (Field (name = "content-length" , values = [str (content_length )]))
175
+
153
176
self .result = _HTTPRequest (
154
177
method = self ._http_trait .method ,
155
178
destination = URI (
@@ -160,10 +183,22 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
160
183
prefix = self ._http_trait .query or "" ,
161
184
),
162
185
),
163
- fields = tuples_to_fields ( headers ) ,
186
+ fields = fields ,
164
187
body = payload ,
165
188
)
166
189
190
+ def _seek (self , payload : Any , pos : int , whence : int = 0 ) -> None :
191
+ if (
192
+ seek := getattr (payload , "seek" , None )
193
+ ) is not None and not iscoroutinefunction (seek ):
194
+ seek (pos , whence )
195
+
196
+ def _compute_content_length (self , payload : Any ) -> int | None :
197
+ content_length = self ._seek (payload , 0 , 2 )
198
+ if content_length is not None :
199
+ self ._seek (payload , 0 , 0 )
200
+ return content_length
201
+
167
202
168
203
class HTTPRequestBindingSerializer (InterceptingSerializer ):
169
204
"""Delegates HTTP request bindings to binding-location-specific serializers."""
@@ -235,8 +270,12 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
235
270
binding_serializer : HTTPResponseBindingSerializer
236
271
237
272
content_type : str | None = self ._payload_codec .media_type
273
+ content_length : int | None = None
274
+ content_length_required = False
275
+
238
276
binding_matcher = ResponseBindingMatcher (schema )
239
277
if (payload_member := binding_matcher .payload_member ) is not None :
278
+ content_length_required = RequiresLengthTrait in payload_member
240
279
if payload_member .shape_type in (ShapeType .BLOB , ShapeType .STRING ):
241
280
if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
242
281
content_type = media_type .value
@@ -250,6 +289,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
250
289
)
251
290
yield binding_serializer
252
291
payload = payload_serializer .payload
292
+ try :
293
+ content_length = len (payload )
294
+ except TypeError :
295
+ pass
253
296
else :
254
297
if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
255
298
content_type = media_type .value
@@ -259,6 +302,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
259
302
payload_serializer , binding_matcher
260
303
)
261
304
yield binding_serializer
305
+ content_length = payload .tell ()
262
306
else :
263
307
payload = BytesIO ()
264
308
payload_serializer = self ._payload_codec .create_serializer (payload )
@@ -270,23 +314,34 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
270
314
body_serializer , binding_matcher
271
315
)
272
316
yield binding_serializer
317
+ content_length = payload .tell ()
273
318
else :
274
319
content_type = None
320
+ content_length = 0
275
321
binding_serializer = HTTPResponseBindingSerializer (
276
322
payload_serializer ,
277
323
binding_matcher ,
278
324
)
279
325
yield binding_serializer
280
326
281
- if (
282
- seek := getattr (payload , "seek" , None )
283
- ) is not None and not iscoroutinefunction (seek ):
284
- seek (0 )
285
-
327
+ self ._seek (payload , 0 )
286
328
headers = binding_serializer .header_serializer .headers
287
329
if content_type is not None :
288
330
headers .append (("content-type" , content_type ))
289
331
332
+ if content_length is not None :
333
+ headers .append (("content-length" , str (content_length )))
334
+
335
+ fields = tuples_to_fields (headers )
336
+ if content_length_required and "content-length" not in fields :
337
+ content_length = self ._compute_content_length (payload )
338
+ if content_length is None :
339
+ raise SerializationError (
340
+ "This operation requires the the content length of the input "
341
+ "stream, but it was not provided and was unable to be computed."
342
+ )
343
+ fields .set_field (Field (name = "content-length" , values = [str (content_length )]))
344
+
290
345
status = binding_serializer .response_code_serializer .response_code
291
346
if status is None :
292
347
if binding_matcher .response_status > 0 :
@@ -300,6 +355,18 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
300
355
status = status ,
301
356
)
302
357
358
+ def _seek (self , payload : Any , pos : int , whence : int = 0 ) -> int | None :
359
+ if (
360
+ seek := getattr (payload , "seek" , None )
361
+ ) is not None and not iscoroutinefunction (seek ):
362
+ return seek (pos , whence )
363
+
364
+ def _compute_content_length (self , payload : Any ) -> int | None :
365
+ content_length = self ._seek (payload , 0 , 2 )
366
+ if content_length is not None :
367
+ self ._seek (payload , 0 , 0 )
368
+ return content_length
369
+
303
370
304
371
class HTTPResponseBindingSerializer (InterceptingSerializer ):
305
372
"""Delegates HTTP response bindings to binding-location-specific serializers."""
0 commit comments