Skip to content

Commit 963ed7d

Browse files
Compute content length
1 parent 201e52b commit 963ed7d

File tree

3 files changed

+269
-63
lines changed

3 files changed

+269
-63
lines changed

packages/smithy-core/src/smithy_core/traits.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,12 @@ def __post_init__(self):
154154
assert self.document_value is None
155155

156156

157+
@dataclass(init=False, frozen=True)
158+
class RequiresLengthTrait(Trait, id=ShapeID("smithy.api#requiresLength")):
159+
def __post_init__(self):
160+
assert self.document_value is None
161+
162+
157163
@dataclass(init=False, frozen=True)
158164
class UnitTypeTrait(Trait, id=ShapeID("smithy.api#UnitTypeTrait")):
159165
def __post_init__(self):

packages/smithy-http/src/smithy_http/serializers.py

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from smithy_core import URI
1111
from smithy_core.codecs import Codec
12+
from smithy_core.exceptions import SerializationError
1213
from smithy_core.schemas import Schema
1314
from smithy_core.serializers import (
1415
InterceptingSerializer,
@@ -24,12 +25,13 @@
2425
HTTPQueryTrait,
2526
HTTPTrait,
2627
MediaTypeTrait,
28+
RequiresLengthTrait,
2729
TimestampFormatTrait,
2830
)
2931
from smithy_core.types import PathPattern, TimestampFormat
3032
from smithy_core.utils import serialize_float
3133

32-
from . import tuples_to_fields
34+
from . import Field, tuples_to_fields
3335
from .aio import HTTPRequest as _HTTPRequest
3436
from .aio import HTTPResponse as _HTTPResponse
3537
from .aio.interfaces import HTTPRequest, HTTPResponse
@@ -43,6 +45,7 @@
4345
__all__ = ["HTTPRequestSerializer", "HTTPResponseSerializer"]
4446

4547

48+
# TODO: refactor this to share code with response serializer
4649
class HTTPRequestSerializer(SpecificShapeSerializer):
4750
"""Binds a serializable shape to an HTTP request.
4851
@@ -82,8 +85,12 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
8285
host_prefix = self._endpoint_trait.host_prefix
8386

8487
content_type = self._payload_codec.media_type
88+
content_length: int | None = None
89+
content_length_required = False
90+
8591
binding_matcher = RequestBindingMatcher(schema)
8692
if (payload_member := binding_matcher.payload_member) is not None:
93+
content_length_required = RequiresLengthTrait in payload_member
8794
if payload_member.shape_type in (
8895
ShapeType.BLOB,
8996
ShapeType.STRING,
@@ -105,6 +112,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
105112
)
106113
yield binding_serializer
107114
payload = payload_serializer.payload
115+
try:
116+
content_length = len(payload)
117+
except TypeError:
118+
pass
108119
else:
109120
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
110121
content_type = media_type.value
@@ -117,6 +128,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
117128
binding_matcher,
118129
)
119130
yield binding_serializer
131+
content_length = payload.tell()
120132
else:
121133
payload = BytesIO()
122134
payload_serializer = self._payload_codec.create_serializer(payload)
@@ -131,8 +143,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
131143
binding_matcher,
132144
)
133145
yield binding_serializer
146+
content_length = payload.tell()
134147
else:
135148
content_type = None
149+
content_length = 0
136150
binding_serializer = HTTPRequestBindingSerializer(
137151
payload_serializer,
138152
self._http_trait.path,
@@ -141,15 +155,24 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
141155
)
142156
yield binding_serializer
143157

144-
if (
145-
seek := getattr(payload, "seek", None)
146-
) is not None and not iscoroutinefunction(seek):
147-
seek(0)
148-
158+
self._seek(payload, 0)
149159
headers = binding_serializer.header_serializer.headers
150160
if content_type is not None:
151161
headers.append(("content-type", content_type))
152162

163+
if content_length is not None:
164+
headers.append(("content-length", str(content_length)))
165+
166+
fields = tuples_to_fields(headers)
167+
if content_length_required and "content-length" not in fields:
168+
content_length = self._compute_content_length(payload)
169+
if content_length is None:
170+
raise SerializationError(
171+
"This operation requires the the content length of the input "
172+
"stream, but it was not provided and was unable to be computed."
173+
)
174+
fields.set_field(Field(name="content-length", values=[str(content_length)]))
175+
153176
self.result = _HTTPRequest(
154177
method=self._http_trait.method,
155178
destination=URI(
@@ -160,10 +183,22 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
160183
prefix=self._http_trait.query or "",
161184
),
162185
),
163-
fields=tuples_to_fields(headers),
186+
fields=fields,
164187
body=payload,
165188
)
166189

190+
def _seek(self, payload: Any, pos: int, whence: int = 0) -> None:
191+
if (
192+
seek := getattr(payload, "seek", None)
193+
) is not None and not iscoroutinefunction(seek):
194+
seek(pos, whence)
195+
196+
def _compute_content_length(self, payload: Any) -> int | None:
197+
content_length = self._seek(payload, 0, 2)
198+
if content_length is not None:
199+
self._seek(payload, 0, 0)
200+
return content_length
201+
167202

168203
class HTTPRequestBindingSerializer(InterceptingSerializer):
169204
"""Delegates HTTP request bindings to binding-location-specific serializers."""
@@ -235,8 +270,12 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
235270
binding_serializer: HTTPResponseBindingSerializer
236271

237272
content_type: str | None = self._payload_codec.media_type
273+
content_length: int | None = None
274+
content_length_required = False
275+
238276
binding_matcher = ResponseBindingMatcher(schema)
239277
if (payload_member := binding_matcher.payload_member) is not None:
278+
content_length_required = RequiresLengthTrait in payload_member
240279
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
241280
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
242281
content_type = media_type.value
@@ -250,6 +289,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
250289
)
251290
yield binding_serializer
252291
payload = payload_serializer.payload
292+
try:
293+
content_length = len(payload)
294+
except TypeError:
295+
pass
253296
else:
254297
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
255298
content_type = media_type.value
@@ -259,6 +302,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
259302
payload_serializer, binding_matcher
260303
)
261304
yield binding_serializer
305+
content_length = payload.tell()
262306
else:
263307
payload = BytesIO()
264308
payload_serializer = self._payload_codec.create_serializer(payload)
@@ -270,23 +314,34 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
270314
body_serializer, binding_matcher
271315
)
272316
yield binding_serializer
317+
content_length = payload.tell()
273318
else:
274319
content_type = None
320+
content_length = 0
275321
binding_serializer = HTTPResponseBindingSerializer(
276322
payload_serializer,
277323
binding_matcher,
278324
)
279325
yield binding_serializer
280326

281-
if (
282-
seek := getattr(payload, "seek", None)
283-
) is not None and not iscoroutinefunction(seek):
284-
seek(0)
285-
327+
self._seek(payload, 0)
286328
headers = binding_serializer.header_serializer.headers
287329
if content_type is not None:
288330
headers.append(("content-type", content_type))
289331

332+
if content_length is not None:
333+
headers.append(("content-length", str(content_length)))
334+
335+
fields = tuples_to_fields(headers)
336+
if content_length_required and "content-length" not in fields:
337+
content_length = self._compute_content_length(payload)
338+
if content_length is None:
339+
raise SerializationError(
340+
"This operation requires the the content length of the input "
341+
"stream, but it was not provided and was unable to be computed."
342+
)
343+
fields.set_field(Field(name="content-length", values=[str(content_length)]))
344+
290345
status = binding_serializer.response_code_serializer.response_code
291346
if status is None:
292347
if binding_matcher.response_status > 0:
@@ -300,6 +355,18 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
300355
status=status,
301356
)
302357

358+
def _seek(self, payload: Any, pos: int, whence: int = 0) -> int | None:
359+
if (
360+
seek := getattr(payload, "seek", None)
361+
) is not None and not iscoroutinefunction(seek):
362+
return seek(pos, whence)
363+
364+
def _compute_content_length(self, payload: Any) -> int | None:
365+
content_length = self._seek(payload, 0, 2)
366+
if content_length is not None:
367+
self._seek(payload, 0, 0)
368+
return content_length
369+
303370

304371
class HTTPResponseBindingSerializer(InterceptingSerializer):
305372
"""Delegates HTTP response bindings to binding-location-specific serializers."""

0 commit comments

Comments
 (0)