Skip to content

Commit 201e52b

Browse files
Only write content type when necessary
1 parent 5520941 commit 201e52b

File tree

2 files changed

+44
-24
lines changed

2 files changed

+44
-24
lines changed

packages/smithy-http/src/smithy_http/serializers.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,20 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
8282
host_prefix = self._endpoint_trait.host_prefix
8383

8484
content_type = self._payload_codec.media_type
85-
8685
binding_matcher = RequestBindingMatcher(schema)
8786
if (payload_member := binding_matcher.payload_member) is not None:
88-
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
89-
content_type = (
90-
"application/octet-stream"
91-
if payload_member.shape_type is ShapeType.BLOB
92-
else "text/plain"
93-
)
87+
if payload_member.shape_type in (
88+
ShapeType.BLOB,
89+
ShapeType.STRING,
90+
ShapeType.ENUM,
91+
):
92+
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
93+
content_type = media_type.value
94+
elif payload_member.shape_type is ShapeType.BLOB:
95+
content_type = "application/octet-stream"
96+
else:
97+
content_type = "text/plain"
98+
9499
payload_serializer = RawPayloadSerializer()
95100
binding_serializer = HTTPRequestBindingSerializer(
96101
payload_serializer,
@@ -113,11 +118,11 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
113118
)
114119
yield binding_serializer
115120
else:
116-
if binding_matcher.event_stream_member is not None:
117-
content_type = "application/vnd.amazon.eventstream"
118121
payload = BytesIO()
119122
payload_serializer = self._payload_codec.create_serializer(payload)
120123
if binding_matcher.should_write_body(self._omit_empty_payload):
124+
if binding_matcher.event_stream_member is not None:
125+
content_type = "application/vnd.amazon.eventstream"
121126
with payload_serializer.begin_struct(schema) as body_serializer:
122127
binding_serializer = HTTPRequestBindingSerializer(
123128
body_serializer,
@@ -127,6 +132,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
127132
)
128133
yield binding_serializer
129134
else:
135+
content_type = None
130136
binding_serializer = HTTPRequestBindingSerializer(
131137
payload_serializer,
132138
self._http_trait.path,
@@ -140,9 +146,9 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
140146
) is not None and not iscoroutinefunction(seek):
141147
seek(0)
142148

143-
# TODO: conditional on empty-ness and based on the protocol
144149
headers = binding_serializer.header_serializer.headers
145-
headers.append(("content-type", content_type))
150+
if content_type is not None:
151+
headers.append(("content-type", content_type))
146152

147153
self.result = _HTTPRequest(
148154
method=self._http_trait.method,
@@ -228,16 +234,25 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
228234
payload: Any
229235
binding_serializer: HTTPResponseBindingSerializer
230236

237+
content_type: str | None = self._payload_codec.media_type
231238
binding_matcher = ResponseBindingMatcher(schema)
232239
if (payload_member := binding_matcher.payload_member) is not None:
233240
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
241+
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
242+
content_type = media_type.value
243+
elif payload_member.shape_type is ShapeType.BLOB:
244+
content_type = "application/octet-stream"
245+
else:
246+
content_type = "text/plain"
234247
payload_serializer = RawPayloadSerializer()
235248
binding_serializer = HTTPResponseBindingSerializer(
236249
payload_serializer, binding_matcher
237250
)
238251
yield binding_serializer
239252
payload = payload_serializer.payload
240253
else:
254+
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
255+
content_type = media_type.value
241256
payload = BytesIO()
242257
payload_serializer = self._payload_codec.create_serializer(payload)
243258
binding_serializer = HTTPResponseBindingSerializer(
@@ -248,12 +263,15 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
248263
payload = BytesIO()
249264
payload_serializer = self._payload_codec.create_serializer(payload)
250265
if binding_matcher.should_write_body(self._omit_empty_payload):
266+
if binding_matcher.event_stream_member is not None:
267+
content_type = "application/vnd.amazon.eventstream"
251268
with payload_serializer.begin_struct(schema) as body_serializer:
252269
binding_serializer = HTTPResponseBindingSerializer(
253270
body_serializer, binding_matcher
254271
)
255272
yield binding_serializer
256273
else:
274+
content_type = None
257275
binding_serializer = HTTPResponseBindingSerializer(
258276
payload_serializer,
259277
binding_matcher,
@@ -265,6 +283,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
265283
) is not None and not iscoroutinefunction(seek):
266284
seek(0)
267285

286+
headers = binding_serializer.header_serializer.headers
287+
if content_type is not None:
288+
headers.append(("content-type", content_type))
289+
268290
status = binding_serializer.response_code_serializer.response_code
269291
if status is None:
270292
if binding_matcher.response_status > 0:

packages/smithy-http/tests/unit/test_serializers.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
TimestampFormatTrait,
4141
Trait,
4242
)
43-
from smithy_http import Field, Fields, tuples_to_fields
43+
from smithy_http import Fields, tuples_to_fields
4444
from smithy_http.aio import HTTPResponse as _HTTPResponse
4545
from smithy_http.deserializers import HTTPResponseDeserializer
4646
from smithy_http.serializers import HTTPRequestSerializer, HTTPResponseSerializer
@@ -1552,7 +1552,9 @@ def host_cases() -> list[HTTPMessageTestCase]:
15521552
HTTPMessageTestCase(
15531553
HostLabel("foo"),
15541554
HTTPMessage(
1555-
destination=URI(host="foo.", path="/"), body=BytesIO(b'{"label":"foo"}')
1555+
destination=URI(host="foo.", path="/"),
1556+
body=BytesIO(b'{"label":"foo"}'),
1557+
fields=tuples_to_fields([("content-type", "application/json")]),
15561558
),
15571559
endpoint_trait=EndpointTrait({"hostPrefix": "{label}."}),
15581560
),
@@ -1569,7 +1571,9 @@ def payload_cases() -> list[HTTPMessageTestCase]:
15691571
HTTPMessageTestCase(
15701572
HTTPImplicitPayload(header="foo", payload_member="bar"),
15711573
HTTPMessage(
1572-
fields=tuples_to_fields([("header", "foo")]),
1574+
fields=tuples_to_fields(
1575+
[("header", "foo"), ("content-type", "application/json")]
1576+
),
15731577
body=BytesIO(b'{"payload_member":"bar"}'),
15741578
),
15751579
),
@@ -1588,7 +1592,10 @@ def payload_cases() -> list[HTTPMessageTestCase]:
15881592
),
15891593
HTTPMessageTestCase(
15901594
HTTPStructuredPayload(payload=HTTPStringPayload(payload="foo")),
1591-
HTTPMessage(body=BytesIO(b'{"payload":"foo"}')),
1595+
HTTPMessage(
1596+
body=BytesIO(b'{"payload":"foo"}'),
1597+
fields=tuples_to_fields([("content-type", "application/json")]),
1598+
),
15921599
),
15931600
]
15941601

@@ -1615,8 +1622,6 @@ def async_streaming_payload_cases() -> list[HTTPMessageTestCase]:
16151622
+ async_streaming_payload_cases()
16161623
)
16171624

1618-
CONTENT_TYPE_FIELD = Field(name="content-type", values=["application/json"])
1619-
16201625

16211626
@pytest.mark.parametrize("case", REQUEST_SER_CASES)
16221627
async def test_serialize_http_request(case: HTTPMessageTestCase) -> None:
@@ -1636,10 +1641,6 @@ async def test_serialize_http_request(case: HTTPMessageTestCase) -> None:
16361641
actual_query = actual.destination.query or ""
16371642
expected_query = case.request.destination.query or ""
16381643
assert actual_query == expected_query
1639-
# set the content-type field here, otherwise cases would have to duplicate it everywhere,
1640-
# but if the field is already set in the case, don't override it
1641-
if expected.fields.get(CONTENT_TYPE_FIELD.name) is None:
1642-
expected.fields.set_field(CONTENT_TYPE_FIELD)
16431644
assert actual.fields == expected.fields
16441645

16451646
if case.request.body:
@@ -1678,9 +1679,6 @@ async def test_serialize_http_response(case: HTTPMessageTestCase) -> None:
16781679
expected = case.request
16791680

16801681
assert actual is not None
1681-
# Remove content-type from expected, we're re-using the request cases for brevity
1682-
if expected.fields.get(CONTENT_TYPE_FIELD.name) is not None:
1683-
del expected.fields[CONTENT_TYPE_FIELD.name]
16841682
assert actual.fields == expected.fields
16851683
assert actual.status == expected.status
16861684

0 commit comments

Comments
 (0)