Skip to content

Fix media and content type #509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions packages/smithy-core/src/smithy_core/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,12 @@ def __post_init__(self):
assert self.document_value is None


@dataclass(init=False, frozen=True)
class RequiresLengthTrait(Trait, id=ShapeID("smithy.api#requiresLength")):
def __post_init__(self):
assert self.document_value is None


@dataclass(init=False, frozen=True)
class UnitTypeTrait(Trait, id=ShapeID("smithy.api#UnitTypeTrait")):
def __post_init__(self):
Expand Down
131 changes: 110 additions & 21 deletions packages/smithy-http/src/smithy_http/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from smithy_core import URI
from smithy_core.codecs import Codec
from smithy_core.exceptions import SerializationError
from smithy_core.schemas import Schema
from smithy_core.serializers import (
InterceptingSerializer,
Expand All @@ -24,12 +25,13 @@
HTTPQueryTrait,
HTTPTrait,
MediaTypeTrait,
RequiresLengthTrait,
TimestampFormatTrait,
)
from smithy_core.types import PathPattern, TimestampFormat
from smithy_core.utils import serialize_float

from . import tuples_to_fields
from . import Field, tuples_to_fields
from .aio import HTTPRequest as _HTTPRequest
from .aio import HTTPResponse as _HTTPResponse
from .aio.interfaces import HTTPRequest, HTTPResponse
Expand All @@ -43,6 +45,7 @@
__all__ = ["HTTPRequestSerializer", "HTTPResponseSerializer"]


# TODO: refactor this to share code with response serializer
class HTTPRequestSerializer(SpecificShapeSerializer):
"""Binds a serializable shape to an HTTP request.

Expand Down Expand Up @@ -82,15 +85,24 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
host_prefix = self._endpoint_trait.host_prefix

content_type = self._payload_codec.media_type
content_length: int | None = None
content_length_required = False

binding_matcher = RequestBindingMatcher(schema)
if (payload_member := binding_matcher.payload_member) is not None:
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
content_type = (
"application/octet-stream"
if payload_member.shape_type is ShapeType.BLOB
else "text/plain"
)
content_length_required = RequiresLengthTrait in payload_member
if payload_member.shape_type in (
ShapeType.BLOB,
ShapeType.STRING,
ShapeType.ENUM,
):
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
content_type = media_type.value
elif payload_member.shape_type is ShapeType.BLOB:
content_type = "application/octet-stream"
else:
content_type = "text/plain"

payload_serializer = RawPayloadSerializer()
binding_serializer = HTTPRequestBindingSerializer(
payload_serializer,
Expand All @@ -100,6 +112,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
)
yield binding_serializer
payload = payload_serializer.payload
try:
content_length = len(payload)
except TypeError:
pass
else:
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
content_type = media_type.value
Expand All @@ -112,12 +128,13 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
binding_matcher,
)
yield binding_serializer
content_length = payload.tell()
else:
if binding_matcher.event_stream_member is not None:
content_type = "application/vnd.amazon.eventstream"
payload = BytesIO()
payload_serializer = self._payload_codec.create_serializer(payload)
if binding_matcher.should_write_body(self._omit_empty_payload):
if binding_matcher.event_stream_member is not None:
content_type = "application/vnd.amazon.eventstream"
with payload_serializer.begin_struct(schema) as body_serializer:
binding_serializer = HTTPRequestBindingSerializer(
body_serializer,
Expand All @@ -126,7 +143,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
binding_matcher,
)
yield binding_serializer
content_length = payload.tell()
else:
content_type = None
content_length = 0
binding_serializer = HTTPRequestBindingSerializer(
payload_serializer,
self._http_trait.path,
Expand All @@ -135,14 +155,23 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
)
yield binding_serializer

if (
seek := getattr(payload, "seek", None)
) is not None and not iscoroutinefunction(seek):
seek(0)

# TODO: conditional on empty-ness and based on the protocol
self._seek(payload, 0)
headers = binding_serializer.header_serializer.headers
headers.append(("content-type", content_type))
if content_type is not None:
headers.append(("content-type", content_type))

if content_length is not None:
headers.append(("content-length", str(content_length)))

fields = tuples_to_fields(headers)
if content_length_required and "content-length" not in fields:
content_length = self._compute_content_length(payload)
if content_length is None:
raise SerializationError(
"This operation requires the the content length of the input "
"stream, but it was not provided and was unable to be computed."
)
fields.set_field(Field(name="content-length", values=[str(content_length)]))

self.result = _HTTPRequest(
method=self._http_trait.method,
Expand All @@ -154,10 +183,22 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
prefix=self._http_trait.query or "",
),
),
fields=tuples_to_fields(headers),
fields=fields,
body=payload,
)

def _seek(self, payload: Any, pos: int, whence: int = 0) -> None:
if (
seek := getattr(payload, "seek", None)
) is not None and not iscoroutinefunction(seek):
seek(pos, whence)

def _compute_content_length(self, payload: Any) -> int | None:
content_length = self._seek(payload, 0, 2)
if content_length is not None:
self._seek(payload, 0, 0)
return content_length


class HTTPRequestBindingSerializer(InterceptingSerializer):
"""Delegates HTTP request bindings to binding-location-specific serializers."""
Expand Down Expand Up @@ -228,42 +269,78 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
payload: Any
binding_serializer: HTTPResponseBindingSerializer

content_type: str | None = self._payload_codec.media_type
content_length: int | None = None
content_length_required = False

binding_matcher = ResponseBindingMatcher(schema)
if (payload_member := binding_matcher.payload_member) is not None:
content_length_required = RequiresLengthTrait in payload_member
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
content_type = media_type.value
elif payload_member.shape_type is ShapeType.BLOB:
content_type = "application/octet-stream"
else:
content_type = "text/plain"
payload_serializer = RawPayloadSerializer()
binding_serializer = HTTPResponseBindingSerializer(
payload_serializer, binding_matcher
)
yield binding_serializer
payload = payload_serializer.payload
try:
content_length = len(payload)
except TypeError:
pass
else:
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
content_type = media_type.value
payload = BytesIO()
payload_serializer = self._payload_codec.create_serializer(payload)
binding_serializer = HTTPResponseBindingSerializer(
payload_serializer, binding_matcher
)
yield binding_serializer
content_length = payload.tell()
else:
payload = BytesIO()
payload_serializer = self._payload_codec.create_serializer(payload)
if binding_matcher.should_write_body(self._omit_empty_payload):
if binding_matcher.event_stream_member is not None:
content_type = "application/vnd.amazon.eventstream"
with payload_serializer.begin_struct(schema) as body_serializer:
binding_serializer = HTTPResponseBindingSerializer(
body_serializer, binding_matcher
)
yield binding_serializer
content_length = payload.tell()
else:
content_type = None
content_length = 0
binding_serializer = HTTPResponseBindingSerializer(
payload_serializer,
binding_matcher,
)
yield binding_serializer

if (
seek := getattr(payload, "seek", None)
) is not None and not iscoroutinefunction(seek):
seek(0)
self._seek(payload, 0)
headers = binding_serializer.header_serializer.headers
if content_type is not None:
headers.append(("content-type", content_type))

if content_length is not None:
headers.append(("content-length", str(content_length)))

fields = tuples_to_fields(headers)
if content_length_required and "content-length" not in fields:
content_length = self._compute_content_length(payload)
if content_length is None:
raise SerializationError(
"This operation requires the the content length of the input "
"stream, but it was not provided and was unable to be computed."
)
fields.set_field(Field(name="content-length", values=[str(content_length)]))

status = binding_serializer.response_code_serializer.response_code
if status is None:
Expand All @@ -278,6 +355,18 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
status=status,
)

def _seek(self, payload: Any, pos: int, whence: int = 0) -> int | None:
if (
seek := getattr(payload, "seek", None)
) is not None and not iscoroutinefunction(seek):
return seek(pos, whence)

def _compute_content_length(self, payload: Any) -> int | None:
content_length = self._seek(payload, 0, 2)
if content_length is not None:
self._seek(payload, 0, 0)
return content_length


class HTTPResponseBindingSerializer(InterceptingSerializer):
"""Delegates HTTP response bindings to binding-location-specific serializers."""
Expand Down
Loading