From c1f7bdd53e3866c47580e7cb68e6821d72d793e7 Mon Sep 17 00:00:00 2001 From: Rafael Winterhalter Date: Wed, 20 Apr 2022 22:28:24 +0200 Subject: [PATCH] Add interoperability with InputStream/OutputStream with WebClient and a JDK client connector for RestTemplate. This change simplifies the use of syncronous HTTP APIs in Spring by adding support for InputStream/OutputStream for WebClient. While WebClient advertises that it offers an easy API for synchronous and asyncronous APIs, it is currently not possible to integrate against libraries that only work with InputStream or OutputStream. Typically, this leads to intermediate manifestation of inputs and outputs as byte arrays what causes additional overhead. This change therefore suggests a codec where the incoming ByteBuffers are piped to an InputStream and body extractors and inserters that work with streams. While this still requires blocking a thread, it avoids the intermediate buffering. It also simplifies code as the intermediate buffering does not need to be implemented. Additionally, this change suggests a ClientHttpRequestFactory for RestTemplate that uses the JDK client directly or any new connector via a WebClient-based request factory. As the RestTemplate is no longer advertised as to be deprecated, and since the JDK client offers APIs for synchronous use with streams, offering this support seems sensible. This also allows for using RestTemplate and WebClient based on the same client which reduces the need to, for example, setup proxy configurations for two different HTTP client implementations. --- .../core/codec/InputStreamDecoder.java | 270 +++++++++++++++ .../core/codec/InputStreamDecoderTests.java | 151 ++++++++ .../AbstractBufferingClientHttpRequest.java | 2 +- .../http/client/JdkClientHttpRequest.java | 123 +++++++ .../client/JdkClientHttpRequestFactory.java | 117 +++++++ .../http/client/JdkClientHttpResponse.java | 93 +++++ .../client/JdkClientStreamingHttpRequest.java | 216 ++++++++++++ .../http/codec/support/BaseDefaultCodecs.java | 2 + .../web/client/RestTemplate.java | 2 + .../JdkClientHttpRequestFactoryTests.java | 26 ++ ...amingJdkClientHttpRequestFactoryTests.java | 28 ++ .../client/RestTemplateIntegrationTests.java | 8 +- .../web/reactive/function/BodyExtractors.java | 322 ++++++++++++++++++ .../web/reactive/function/BodyInserters.java | 146 ++++++++ .../function/WebClientHttpRequest.java | 82 +++++ .../function/WebClientHttpRequestFactory.java | 94 +++++ .../function/WebClientHttpResponse.java | 85 +++++ .../WebClientStreamingHttpRequest.java | 99 ++++++ .../function/BodyExtractorsTests.java | 41 +++ .../reactive/function/BodyInsertersTests.java | 44 +++ 20 files changed, 1945 insertions(+), 6 deletions(-) create mode 100644 spring-core/src/main/java/org/springframework/core/codec/InputStreamDecoder.java create mode 100644 spring-core/src/test/java/org/springframework/core/codec/InputStreamDecoderTests.java create mode 100644 spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java create mode 100644 spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java create mode 100644 spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java create mode 100644 spring-web/src/main/java/org/springframework/http/client/JdkClientStreamingHttpRequest.java create mode 100644 spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java create mode 100644 spring-web/src/test/java/org/springframework/http/client/StreamingJdkClientHttpRequestFactoryTests.java create mode 100644 spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequest.java create mode 100644 spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequestFactory.java create mode 100644 spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpResponse.java create mode 100644 spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientStreamingHttpRequest.java diff --git a/spring-core/src/main/java/org/springframework/core/codec/InputStreamDecoder.java b/spring-core/src/main/java/org/springframework/core/codec/InputStreamDecoder.java new file mode 100644 index 000000000000..224a499af022 --- /dev/null +++ b/spring-core/src/main/java/org/springframework/core/codec/InputStreamDecoder.java @@ -0,0 +1,270 @@ +package org.springframework.core.codec; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +/** + * Decoder that translates data buffers to an {@link InputStream}. + */ +public class InputStreamDecoder extends AbstractDataBufferDecoder { + + public static final String FAIL_FAST = InputStreamDecoder.class.getName() + ".FAIL_FAST"; + + public InputStreamDecoder() { + super(MimeTypeUtils.ALL); + } + + @Override + public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) { + return (elementType.resolve() == InputStream.class && super.canDecode(elementType, mimeType)); + } + + @Override + public InputStream decode(DataBuffer dataBuffer, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + if (logger.isDebugEnabled()) { + logger.debug(Hints.getLogPrefix(hints) + "Reading " + dataBuffer.readableByteCount() + " bytes"); + } + return dataBuffer.asInputStream(true); + } + + @Override + public Mono decodeToMono(Publisher input, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + boolean failFast = hints == null || Boolean.TRUE.equals(hints.getOrDefault(FAIL_FAST, Boolean.TRUE)); + FlowBufferInputStream inputStream = new FlowBufferInputStream(getMaxInMemorySize(), failFast); + Flux.from(input).subscribe(inputStream); + + return Mono.just(inputStream); + } + + static class FlowBufferInputStream extends InputStream implements Subscriber { + + private static final Object END = new Object(); + + private final AtomicBoolean closed = new AtomicBoolean(); + + private final BlockingQueue backlog; + + private final int maximumMemorySize; + + private final boolean failFast; + + private final AtomicInteger buffered = new AtomicInteger(); + + @Nullable + private InputStreamWithSize current = new InputStreamWithSize(0, InputStream.nullInputStream()); + + @Nullable + private Subscription subscription; + + FlowBufferInputStream(int maximumMemorySize, boolean failFast) { + this.backlog = new LinkedBlockingDeque<>(); + this.maximumMemorySize = maximumMemorySize; + this.failFast = failFast; + } + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + if (this.closed.get()) { + subscription.cancel(); + } else { + subscription.request(1); + } + } + + @Override + public void onNext(DataBuffer buffer) { + if (this.closed.get()) { + DataBufferUtils.release(buffer); + return; + } + int readableByteCount = buffer.readableByteCount(); + int current = this.buffered.addAndGet(readableByteCount); + if (current < this.maximumMemorySize) { + this.subscription.request(1); + } + InputStream stream = buffer.asInputStream(true); + this.backlog.add(new InputStreamWithSize(readableByteCount, stream)); + if (this.closed.get()) { + DataBufferUtils.release(buffer); + } + } + + @Override + public void onError(Throwable throwable) { + if (failFast) { + Object next; + while ((next = this.backlog.poll()) != null) { + if (next instanceof InputStreamWithSize) { + try { + ((InputStreamWithSize) next).inputStream.close(); + } catch (Throwable t) { + throwable.addSuppressed(t); + } + } + } + } + this.backlog.add(throwable); + } + + @Override + public void onComplete() { + this.backlog.add(END); + } + + private boolean forward() throws IOException { + this.current.inputStream.close(); + try { + Object next = this.backlog.take(); + if (next == END) { + this.current = null; + return true; + } else if (next instanceof RuntimeException) { + close(); + throw (RuntimeException) next; + } else if (next instanceof IOException) { + close(); + throw (IOException) next; + } else if (next instanceof Throwable) { + close(); + throw new IllegalStateException((Throwable) next); + } else { + int buffer = buffered.addAndGet(-this.current.size); + if (buffer < this.maximumMemorySize) { + this.subscription.request(1); + } + this.current = (InputStreamWithSize) next; + return false; + } + } catch (InterruptedException e) { + throw new IllegalStateException(e); + } + } + + @Override + public int read() throws IOException { + if (this.closed.get()) { + throw new IOException("closed"); + } else if (this.current == null) { + return -1; + } + int read; + while ((read = this.current.inputStream.read()) == -1) { + if (forward()) { + return -1; + } + } + return read; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + Objects.checkFromIndexSize(off, len, b.length); + if (this.closed.get()) { + throw new IOException("closed"); + } else if (this.current == null) { + return -1; + } else if (len == 0) { + return 0; + } + int sum = 0; + do { + int read = this.current.inputStream.read(b, off + sum, len - sum); + if (read == -1) { + if (sum > 0 && this.backlog.isEmpty()) { + return sum; + } else if (forward()) { + return sum == 0 ? -1 : sum; + } + } else { + sum += read; + } + } while (sum < len); + return sum; + } + + @Override + public int available() throws IOException { + if (this.closed.get()) { + throw new IOException("closed"); + } else if (this.current == null) { + return 0; + } + int available = this.current.inputStream.available(); + for (Object value : this.backlog) { + if (value instanceof InputStreamWithSize) { + available += ((InputStreamWithSize) value).inputStream.available(); + } else { + break; + } + } + return available; + } + + @Override + public void close() throws IOException { + if (this.closed.compareAndSet(false, true)) { + if (this.subscription != null) { + this.subscription.cancel(); + } + IOException exception = null; + if (this.current != null) { + try { + this.current.inputStream.close(); + } catch (IOException e) { + exception = e; + } + } + for (Object value : this.backlog) { + if (value instanceof InputStreamWithSize) { + try { + ((InputStreamWithSize) value).inputStream.close(); + } catch (IOException e) { + if (exception == null) { + exception = e; + } else { + exception.addSuppressed(e); + } + } + } + } + if (exception != null) { + throw exception; + } + } + } + } + + static class InputStreamWithSize { + + final int size; + + final InputStream inputStream; + + InputStreamWithSize(int size, InputStream inputStream) { + this.size = size; + this.inputStream = inputStream; + } + } +} \ No newline at end of file diff --git a/spring-core/src/test/java/org/springframework/core/codec/InputStreamDecoderTests.java b/spring-core/src/test/java/org/springframework/core/codec/InputStreamDecoderTests.java new file mode 100644 index 000000000000..4b1f842abbe4 --- /dev/null +++ b/spring-core/src/test/java/org/springframework/core/codec/InputStreamDecoderTests.java @@ -0,0 +1,151 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.core.codec; + +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.testfixture.codec.AbstractDecoderTests; +import org.springframework.lang.Nullable; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Map; +import java.util.function.Consumer; + +import static org.assertj.core.api.Assertions.*; + +/** + * @author Vladislav Kisel + */ +class InputStreamDecoderTests extends AbstractDecoderTests { + + private final byte[] fooBytes = "foo".getBytes(StandardCharsets.UTF_8); + + private final byte[] barBytes = "bar".getBytes(StandardCharsets.UTF_8); + + + InputStreamDecoderTests() { + super(new InputStreamDecoder()); + } + + @Override + @Test + public void canDecode() { + assertThat(this.decoder.canDecode(ResolvableType.forClass(InputStream.class), + MimeTypeUtils.TEXT_PLAIN)).isTrue(); + assertThat(this.decoder.canDecode(ResolvableType.forClass(Integer.class), + MimeTypeUtils.TEXT_PLAIN)).isFalse(); + assertThat(this.decoder.canDecode(ResolvableType.forClass(InputStream.class), + MimeTypeUtils.APPLICATION_JSON)).isTrue(); + } + + @Override + @Test + public void decode() { + Flux input = Flux.just( + this.bufferFactory.wrap(this.fooBytes), + this.bufferFactory.wrap(this.barBytes)); + + testDecodeAll(input, InputStream.class, step -> step + .consumeNextWith(expectInputStream(this.fooBytes)) + .consumeNextWith(expectInputStream(this.barBytes)) + .verifyComplete()); + } + + @Override + @Test + public void decodeToMono() { + Flux input = Flux.concat( + dataBuffer(this.fooBytes), + dataBuffer(this.barBytes)); + + byte[] expected = new byte[this.fooBytes.length + this.barBytes.length]; + System.arraycopy(this.fooBytes, 0, expected, 0, this.fooBytes.length); + System.arraycopy(this.barBytes, 0, expected, this.fooBytes.length, this.barBytes.length); + + testDecodeToMonoAll(input, InputStream.class, step -> step + .consumeNextWith(expectInputStream(expected)) + .verifyComplete()); + testDecodeToMonoErrorFailLast(input, expected); + } + + @Override + protected void testDecodeToMonoError(Publisher input, ResolvableType outputType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + input = Flux.from(input).concatWith(Flux.error(new InputException())); + try (InputStream result = this.decoder.decodeToMono(input, outputType, mimeType, hints).block()) { + assertThatThrownBy(() -> result.read()).isInstanceOf(InputException.class); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private void testDecodeToMonoErrorFailLast(Publisher input, byte[] expected) { + input = Flux.concatDelayError(Flux.from(input), Flux.error(new InputException())); + try (InputStream result = this.decoder.decodeToMono(input, + ResolvableType.forType(InputStream.class), + null, + Collections.singletonMap(InputStreamDecoder.FAIL_FAST, false)).block()) { + byte[] actual = new byte[expected.length]; + assertThat(result.read(actual)).isEqualTo(expected.length); + assertThat(actual).isEqualTo(expected); + assertThatThrownBy(() -> result.read()).isInstanceOf(InputException.class); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + protected void testDecodeToMonoCancel(Publisher input, ResolvableType outputType, + @Nullable MimeType mimeType, @Nullable Map hints) { } + + @Override + protected void testDecodeToMonoEmpty(ResolvableType outputType, @Nullable MimeType mimeType, + @Nullable Map hints) { + + try (InputStream result = this.decoder.decodeToMono(Flux.empty(), outputType, mimeType, hints).block()) { + assertThat(result.read()).isEqualTo(-1); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private Consumer expectInputStream(byte[] expected) { + return actual -> { + try (actual) { + byte[] actualBytes = actual.readAllBytes(); + assertThat(actualBytes).isEqualTo(expected); + } catch (IOException ignored) { + } + }; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingClientHttpRequest.java index abedb2c051c6..9cba17d3303a 100644 --- a/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingClientHttpRequest.java @@ -29,7 +29,7 @@ * @author Arjen Poutsma * @since 3.0.6 */ -abstract class AbstractBufferingClientHttpRequest extends AbstractClientHttpRequest { +public abstract class AbstractBufferingClientHttpRequest extends AbstractClientHttpRequest { private ByteArrayOutputStream bufferedOutput = new ByteArrayOutputStream(1024); diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java new file mode 100644 index 000000000000..ef08f074c4b7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; + +/** + * {@link ClientHttpRequest} implementation based on + * JDK HTTP client. + * + *

Created via the {@link JdkClientHttpRequestFactory}. + */ +final class JdkClientHttpRequest extends AbstractBufferingClientHttpRequest { + + private final HttpClient httpClient; + + private final HttpMethod method; + + private final URI uri; + + private final boolean expectContinue; + + @Nullable + private final Duration requestTimeout; + + JdkClientHttpRequest(HttpClient client, HttpMethod method, URI uri, + boolean expectContinue, @Nullable Duration requestTimeout) { + this.httpClient = client; + this.method = method; + this.uri = uri; + this.expectContinue = expectContinue; + this.requestTimeout = requestTimeout; + } + + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + @Deprecated + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException { + HttpRequest.Builder builder = HttpRequest.newBuilder(this.uri); + + addHeaders(builder, headers); + + builder.method(this.method.name(), bufferedOutput.length == 0 + ? HttpRequest.BodyPublishers.noBody() + : HttpRequest.BodyPublishers.ofByteArray(bufferedOutput)); + + if (expectContinue) { + builder.expectContinue(true); + } + if (requestTimeout != null) { + builder.timeout(requestTimeout); + } + + HttpResponse response; + try { + response = this.httpClient.send(builder.build(), HttpResponse.BodyHandlers.ofInputStream()); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new JdkClientHttpResponse(response); + } + + /** + * Add the given headers to the given HTTP request. + * @param builder the request builder to add the headers to + * @param headers the headers to add + */ + static void addHeaders(HttpRequest.Builder builder, HttpHeaders headers) { + headers.forEach((headerName, headerValues) -> { + if (HttpHeaders.COOKIE.equalsIgnoreCase(headerName)) { // RFC 6265 + String headerValue = StringUtils.collectionToDelimitedString(headerValues, "; "); + builder.header(headerName, headerValue); + } + else if (!HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(headerName) && + !HttpHeaders.TRANSFER_ENCODING.equalsIgnoreCase(headerName)) { + for (String headerValue : headerValues) { + builder.header(headerName, headerValue); + } + } + }); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java new file mode 100644 index 000000000000..7b8fc4b4ac59 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java @@ -0,0 +1,117 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.time.Duration; + +/** + * {@link org.springframework.http.client.ClientHttpRequestFactory} implementation that + * uses the Java {@link HttpClient}. + */ +public class JdkClientHttpRequestFactory implements ClientHttpRequestFactory { + + private HttpClient httpClient; + + private boolean expectContinue; + + @Nullable + private Duration requestTimeout; + + private boolean bufferRequestBody = true; + + + /** + * Create a new instance of the {@code JdkClientHttpRequestFactory} + * with a default {@link HttpClient}. + */ + public JdkClientHttpRequestFactory() { + this.httpClient = HttpClient.newHttpClient(); + } + + /** + * Create a new instance of the {@code JdkClientHttpRequestFactory} + * with the given {@link HttpClient} instance. + * @param httpClient the HttpClient instance to use for this request factory + */ + public JdkClientHttpRequestFactory(HttpClient httpClient) { + this.httpClient = httpClient; + } + + /** + * Set the {@code HttpClient} used for + * {@linkplain #createRequest(URI, HttpMethod) synchronous execution}. + */ + public void setHttpClient(HttpClient httpClient) { + Assert.notNull(httpClient, "HttpClient must not be null"); + this.httpClient = httpClient; + } + + /** + * Return the {@code HttpClient} used for + * {@linkplain #createRequest(URI, HttpMethod) synchronous execution}. + */ + public HttpClient getHttpClient() { + return this.httpClient; + } + + /** + * If {@code true}, requests the server to acknowledge the request before sending the body. + * @param expectContinue {@code} if the server is requested to acknowledge the request + * @see HttpRequest#expectContinue() + */ + public void setExpectContinue(boolean expectContinue) { + this.expectContinue = expectContinue; + } + + /** + * Set the request timeout for a request. A {code null} of 0 specifies an infinite timeout. + * @param requestTimeout the timeout value or {@code null} to disable the timeout + * @see HttpRequest#timeout() + */ + public void setRequestTimeout(@Nullable Duration requestTimeout) { + this.requestTimeout = requestTimeout; + } + + /** + * Indicates whether this request factory should buffer the request body internally. + *

Default is {@code true}. When sending large amounts of data via POST or PUT, it is + * recommended to change this property to {@code false}, so as not to run out of memory. + */ + public void setBufferRequestBody(boolean bufferRequestBody) { + this.bufferRequestBody = bufferRequestBody; + } + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + HttpClient client = getHttpClient(); + + if (this.bufferRequestBody) { + return new JdkClientHttpRequest(client, httpMethod, uri, expectContinue, requestTimeout); + } + else { + return new JdkClientStreamingHttpRequest(client, httpMethod, uri, expectContinue, requestTimeout); + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java new file mode 100644 index 000000000000..bd21be8e3b6e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatusCode; +import org.springframework.lang.Nullable; + +import java.io.IOException; +import java.io.InputStream; +import java.net.http.HttpResponse; + +/** + * {@link ClientHttpResponse} implementation based on + * JDK HTTP client. + * + *

Created via the {@link JdkClientHttpRequest}. + */ +final class JdkClientHttpResponse implements ClientHttpResponse { + + private final HttpResponse httpResponse; + + @Nullable + private HttpHeaders headers; + + + JdkClientHttpResponse(HttpResponse httpResponse) { + this.httpResponse = httpResponse; + } + + + @Override + public HttpStatusCode getStatusCode() throws IOException { + return HttpStatusCode.valueOf(this.httpResponse.statusCode()); + } + + @Override + @Deprecated + public int getRawStatusCode() throws IOException { + return this.httpResponse.statusCode(); + } + + @Override + public String getStatusText() throws IOException { + return ""; + } + + @Override + public HttpHeaders getHeaders() { + if (this.headers == null) { + this.headers = new HttpHeaders(); + this.httpResponse.headers().map().forEach((key, values) -> this.headers.addAll(key, values)); + } + return this.headers; + } + + @Override + public InputStream getBody() throws IOException { + return this.httpResponse.body(); + } + + @Override + public void close() { + // Release underlying connection back to the connection manager + try { + try { + // Attempt to keep connection alive by consuming its remaining content + this.httpResponse.body().readAllBytes(); + } + finally { + this.httpResponse.body().close(); + } + } + catch (IOException ex) { + // Ignore exception on close... + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientStreamingHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientStreamingHttpRequest.java new file mode 100644 index 000000000000..0a26b743440e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientStreamingHttpRequest.java @@ -0,0 +1,216 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.lang.Nullable; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicReference; + +/** + * {@link ClientHttpRequest} implementation based on + * JDK HTTP client in streaming mode. + * + *

Created via the {@link JdkClientHttpRequestFactory}. + */ +final class JdkClientStreamingHttpRequest extends AbstractClientHttpRequest + implements StreamingHttpOutputMessage { + + private final HttpClient httpClient; + + private final HttpMethod method; + + private final URI uri; + + private final boolean expectContinue; + + @Nullable + private final Duration requestTimeout; + + @Nullable + private Body body; + + JdkClientStreamingHttpRequest(HttpClient client, HttpMethod method, URI uri, + boolean expectContinue, @Nullable Duration requestTimeout) { + this.httpClient = client; + this.method = method; + this.uri = uri; + this.expectContinue = expectContinue; + this.requestTimeout = requestTimeout; + } + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + @Deprecated + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public void setBody(Body body) { + assertNotExecuted(); + this.body = body; + } + + @Override + protected OutputStream getBodyInternal(HttpHeaders headers) throws IOException { + throw new UnsupportedOperationException("getBody not supported"); + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers) throws IOException { + HttpRequest.Builder builder = HttpRequest.newBuilder(this.uri); + + JdkClientHttpRequest.addHeaders(builder, headers); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference reference; + if (this.body != null) { + reference = new AtomicReference<>(); + builder.method(this.method.name(), HttpRequest.BodyPublishers.fromPublisher(subscriber -> { + SubscriptionOutputStream outputStream = new SubscriptionOutputStream(subscriber); + reference.set(outputStream); + latch.countDown(); + try { + subscriber.onSubscribe(outputStream); + } catch (Throwable t) { + outputStream.closed = true; + throw t; + } + })); + } else { + reference = null; + builder.method(this.method.name(), HttpRequest.BodyPublishers.noBody()); + } + + if (expectContinue) { + builder.expectContinue(true); + } + if (requestTimeout != null) { + builder.timeout(requestTimeout); + } + + HttpResponse response; + try { + if (this.body != null) { + CompletableFuture> future = this.httpClient.sendAsync(builder.build(), HttpResponse.BodyHandlers.ofInputStream()); + latch.await(); + SubscriptionOutputStream outputStream = reference.get(); + try (outputStream) { + this.body.writeTo(outputStream); + } catch (Throwable t) { + outputStream.cancel(); + outputStream.subscriber.onError(t); + } + response = future.join(); + } else { + response = this.httpClient.send(builder.build(), HttpResponse.BodyHandlers.ofInputStream());; + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new JdkClientHttpResponse(response); + } + + static class SubscriptionOutputStream extends OutputStream implements Flow.Subscription { + + private final Flow.Subscriber subscriber; + + private final Semaphore semaphore = new Semaphore(0); + + private volatile boolean closed; + + SubscriptionOutputStream(Flow.Subscriber subscriber) { + this.subscriber = subscriber; + } + + @Override + public void write(byte[] b) throws IOException {; + if (acquire()) { + subscriber.onNext(ByteBuffer.wrap(b)); + } + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + if (acquire()) { + subscriber.onNext(ByteBuffer.wrap(b, off, len)); + } + } + + @Override + public void write(int b) throws IOException { + if (acquire()) { + subscriber.onNext(ByteBuffer.wrap(new byte[] {(byte) b})); + } + } + + @Override + public void close() throws IOException { + if (!closed) { + closed = true; + subscriber.onComplete(); + } + } + + private boolean acquire() throws IOException { + if (closed) { + throw new IOException("closed"); + } + try { + semaphore.acquire(); + return true; + } catch (InterruptedException e) { + closed = true; + subscriber.onError(e); + return false; + } + } + + @Override + public void request(long n) { + semaphore.release((int) n); + } + + @Override + public void cancel() { + closed = true; + semaphore.release(1); + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java b/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java index a999661e733e..7c828a3c98b4 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java +++ b/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java @@ -44,6 +44,7 @@ import org.springframework.http.codec.FormHttpMessageWriter; import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.core.codec.InputStreamDecoder; import org.springframework.http.codec.ResourceHttpMessageReader; import org.springframework.http.codec.ResourceHttpMessageWriter; import org.springframework.http.codec.ServerSentEventHttpMessageReader; @@ -341,6 +342,7 @@ protected void initTypedReaders() { addCodec(this.typedReaders, new DecoderHttpMessageReader<>(new ByteArrayDecoder())); addCodec(this.typedReaders, new DecoderHttpMessageReader<>(new ByteBufferDecoder())); addCodec(this.typedReaders, new DecoderHttpMessageReader<>(new DataBufferDecoder())); + addCodec(this.typedReaders, new DecoderHttpMessageReader<>(new InputStreamDecoder())); if (nettyByteBufPresent) { addCodec(this.typedReaders, new DecoderHttpMessageReader<>(new NettyByteBufDecoder())); } diff --git a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java index 31c7c4224cba..a2f70424b4e1 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -40,6 +40,7 @@ import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.client.JdkClientHttpRequestFactory; import org.springframework.http.client.support.InterceptingHttpAccessor; import org.springframework.http.converter.ByteArrayHttpMessageConverter; import org.springframework.http.converter.GenericHttpMessageConverter; @@ -203,6 +204,7 @@ else if (kotlinSerializationJsonPresent) { * @param requestFactory the HTTP request factory to use * @see org.springframework.http.client.SimpleClientHttpRequestFactory * @see org.springframework.http.client.HttpComponentsClientHttpRequestFactory + * @see JdkClientHttpRequestFactory */ public RestTemplate(ClientHttpRequestFactory requestFactory) { this(); diff --git a/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java new file mode 100644 index 000000000000..6c58ed53d17c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java @@ -0,0 +1,26 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +public class JdkClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTests { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + return new JdkClientHttpRequestFactory(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/StreamingJdkClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/StreamingJdkClientHttpRequestFactoryTests.java new file mode 100644 index 000000000000..78089fd62625 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/StreamingJdkClientHttpRequestFactoryTests.java @@ -0,0 +1,28 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +public class StreamingJdkClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTests { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + JdkClientHttpRequestFactory requestFactory = new JdkClientHttpRequestFactory(); + requestFactory.setBufferRequestBody(false); + return requestFactory; + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java index b1076e844e20..bc0fb2005d25 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java @@ -46,10 +46,7 @@ import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; -import org.springframework.http.client.ClientHttpRequestFactory; -import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; -import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; -import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.http.client.*; import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.converter.json.MappingJacksonValue; import org.springframework.util.LinkedMultiValueMap; @@ -93,7 +90,8 @@ static Stream> clientHttpRequestFactories() { return Stream.of( named("JDK", new SimpleClientHttpRequestFactory()), named("HttpComponents", new HttpComponentsClientHttpRequestFactory()), - named("OkHttp", new OkHttp3ClientHttpRequestFactory()) + named("OkHttp", new OkHttp3ClientHttpRequestFactory()), + named("JDKClient", new JdkClientHttpRequestFactory()) ); } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java index 419ae2ceb417..8b308248e587 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java @@ -16,13 +16,25 @@ package org.springframework.web.reactive.function; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.util.List; +import java.util.Objects; import java.util.Optional; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -88,6 +100,101 @@ private static BodyExtractor, ReactiveHttpInputMessage> toMono(Resol skipBodyAsMono(inputMessage)); } + /** + * Variant of {@link BodyExtractors#toMono(InputStreamMapper, boolean, int)} with a + * default buffer size and fast failure. + * @param streamMapper the mapper that is reading the body + * @return {@code BodyExtractor} that reads the response body as input stream + * @param the type of the value that is resolved from the returned stream + */ + public static BodyExtractor, ReactiveHttpInputMessage> toMono( + InputStreamMapper streamMapper) { + return toMono(streamMapper, true); + } + + /** + * Variant of {@link BodyExtractors#toMono(InputStreamMapper, boolean, int)} with a + * default buffer size. + * @param streamMapper the mapper that is reading the body + * @param failFast {@code false} if previously read bytes are discarded upon an error + * @return {@code BodyExtractor} that reads the response body as input stream + * @param the type of the value that is resolved from the returned stream + */ + public static BodyExtractor, ReactiveHttpInputMessage> toMono( + InputStreamMapper streamMapper, + boolean failFast) { + return toMono(streamMapper, failFast, 256 * 1024, true); + } + + /** + * Extractor where the response body is processed by reading an input stream of the + * response body. + * @param streamMapper the mapper that is reading the body + * @param failFast {@code false} if previously read bytes are discarded upon an error + * @param maximumMemorySize the amount of memory that is buffered until reading is suspended + * @return {@code BodyExtractor} that reads the response body as input stream + * @param the type of the value that is resolved from the returned stream + */ + public static BodyExtractor, ReactiveHttpInputMessage> toMono( + InputStreamMapper streamMapper, + boolean failFast, + int maximumMemorySize) { + return toMono(streamMapper, failFast, maximumMemorySize, true); + } + + static BodyExtractor, ReactiveHttpInputMessage> toInputStream() { + return toMono(stream -> stream, true, 256 * 1024, false); + } + + private static BodyExtractor, ReactiveHttpInputMessage> toMono( + InputStreamMapper streamMapper, + boolean failFast, + int maximumMemorySize, + boolean close) { + + Assert.notNull(streamMapper, "'streamMapper' must not be null"); + Assert.isTrue(maximumMemorySize > 0, "'maximumMemorySize' must be positive"); + return (inputMessage, context) -> { + FlowBufferInputStream inputStream = new FlowBufferInputStream(maximumMemorySize, failFast); + try { + inputMessage.getBody().subscribe(inputStream); + T value = streamMapper.apply(inputStream); + if (close) { + inputStream.close(); + } + return Mono.just(value); + } catch (Throwable t) { + try { + inputStream.close(); + } catch (Throwable suppressed) { + t.addSuppressed(suppressed); + } + return Mono.error(t); + } + }; + } + + /** + * Variant of {@link BodyExtractors#toMono(InputStreamMapper, boolean, int, boolean)} with a + * default buffer size. + * @param streamSupplier the supplier of the output stream + * @return {@code BodyExtractor} that reads the response body as input stream + */ + public static BodyExtractor, ReactiveHttpInputMessage> toMono( + Supplier streamSupplier) { + + Assert.notNull(streamSupplier, "'streamSupplier' must not be null"); + return (inputMessage, context) -> { + try (OutputStream outputStream = streamSupplier.get()) { + Flux writeResult = DataBufferUtils.write(inputMessage.getBody(), outputStream); + writeResult.blockLast(); + return Mono.empty(); + } catch (Throwable t) { + return Mono.error(t); + } + }; + } + /** * Extractor to decode the input content into {@code Flux}. * @param elementClass the class of the element type to decode to @@ -277,4 +384,219 @@ private static Flux consumeAndCancel(ReactiveHttpInputMessage messag }); } + @FunctionalInterface + public interface InputStreamMapper { + + T apply(InputStream stream) throws IOException; + } + + static class FlowBufferInputStream extends InputStream implements Subscriber { + + private static final Object END = new Object(); + + private final AtomicBoolean closed = new AtomicBoolean(); + + private final BlockingQueue backlog; + + private final int maximumMemorySize; + + private final boolean failFast; + + private final AtomicInteger buffered = new AtomicInteger(); + + @Nullable + private InputStreamWithSize current = new InputStreamWithSize(0, InputStream.nullInputStream()); + + @Nullable + private Subscription subscription; + + FlowBufferInputStream(int maximumMemorySize, boolean failFast) { + this.backlog = new LinkedBlockingDeque<>(); + this.maximumMemorySize = maximumMemorySize; + this.failFast = failFast; + } + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + if (this.closed.get()) { + subscription.cancel(); + } else { + subscription.request(1); + } + } + + @Override + public void onNext(DataBuffer buffer) { + if (this.closed.get()) { + DataBufferUtils.release(buffer); + return; + } + int readableByteCount = buffer.readableByteCount(); + int current = this.buffered.addAndGet(readableByteCount); + if (current < this.maximumMemorySize) { + this.subscription.request(1); + } + InputStream stream = buffer.asInputStream(true); + this.backlog.add(new InputStreamWithSize(readableByteCount, stream)); + if (this.closed.get()) { + DataBufferUtils.release(buffer); + } + } + + @Override + public void onError(Throwable throwable) { + if (failFast) { + Object next; + while ((next = this.backlog.poll()) != null) { + if (next instanceof InputStreamWithSize) { + try { + ((InputStreamWithSize) next).inputStream.close(); + } catch (Throwable t) { + throwable.addSuppressed(t); + } + } + } + } + this.backlog.add(throwable); + } + + @Override + public void onComplete() { + this.backlog.add(END); + } + + private boolean forward() throws IOException { + this.current.inputStream.close(); + try { + Object next = this.backlog.take(); + if (next == END) { + this.current = null; + return true; + } else if (next instanceof RuntimeException) { + close(); + throw (RuntimeException) next; + } else if (next instanceof IOException) { + close(); + throw (IOException) next; + } else if (next instanceof Throwable) { + close(); + throw new IllegalStateException((Throwable) next); + } else { + int buffer = buffered.addAndGet(-this.current.size); + if (buffer < this.maximumMemorySize) { + this.subscription.request(1); + } + this.current = (InputStreamWithSize) next; + return false; + } + } catch (InterruptedException e) { + throw new IllegalStateException(e); + } + } + + @Override + public int read() throws IOException { + if (this.closed.get()) { + throw new IOException("closed"); + } else if (this.current == null) { + return -1; + } + int read; + while ((read = this.current.inputStream.read()) == -1) { + if (forward()) { + return -1; + } + } + return read; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + Objects.checkFromIndexSize(off, len, b.length); + if (this.closed.get()) { + throw new IOException("closed"); + } else if (this.current == null) { + return -1; + } + int sum = 0; + do { + int read = this.current.inputStream.read(b, off + sum, len - sum); + if (read == -1) { + if (this.backlog.isEmpty()) { + return sum; + } else if (forward()) { + return sum == 0 ? -1 : sum; + } + } else { + sum += read; + } + } while (sum < len); + return sum; + } + + @Override + public int available() throws IOException { + if (this.closed.get()) { + throw new IOException("closed"); + } else if (this.current == null) { + return 0; + } + int available = this.current.inputStream.available(); + for (Object value : this.backlog) { + if (value instanceof InputStreamWithSize) { + available += ((InputStreamWithSize) value).inputStream.available(); + } else { + break; + } + } + return available; + } + + @Override + public void close() throws IOException { + if (this.closed.compareAndSet(false, true)) { + if (this.subscription != null) { + this.subscription.cancel(); + } + IOException exception = null; + if (this.current != null) { + try { + this.current.inputStream.close(); + } catch (IOException e) { + exception = e; + } + } + for (Object value : this.backlog) { + if (value instanceof InputStreamWithSize) { + try { + ((InputStreamWithSize) value).inputStream.close(); + } catch (IOException e) { + if (exception == null) { + exception = e; + } else { + exception.addSuppressed(e); + } + } + } + } + if (exception != null) { + throw exception; + } + } + } + } + + static class InputStreamWithSize { + + final int size; + + final InputStream inputStream; + + InputStreamWithSize(int size, InputStream inputStream) { + this.size = size; + this.inputStream = inputStream; + } + } + } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java index fe951e0f59bb..2f911bef89be 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java @@ -16,10 +16,16 @@ package org.springframework.web.reactive.function; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; import java.util.stream.Collectors; import org.reactivestreams.Publisher; +import org.springframework.core.io.buffer.DataBufferFactory; import reactor.core.publisher.Mono; import org.springframework.core.ParameterizedTypeReference; @@ -40,6 +46,7 @@ import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import reactor.core.publisher.Sinks; /** * Static factory methods for {@link BodyInserter} implementations. @@ -358,6 +365,90 @@ public static > BodyInserter outputMessage.writeWith(publisher); } + /** + * Inserter where the request body is written to an output stream. The stream is closed + * automatically if it is not closed manually. + * @param consumer the consumer that is writing to the output stream + * @return the inserter to write directly to the body via an output stream + */ + public static BodyInserter fromOutputStream( + FromOutputStream consumer) { + + Assert.notNull(consumer, "'publisher' must not be null"); + return (outputMessage, context) -> { + Sinks.Many sink = Sinks.many() + .unicast() + .onBackpressureBuffer(); + + Mono mono = outputMessage.writeWith(sink.asFlux()); + WriterOutputStream outputStream = new WriterOutputStream(outputMessage.bufferFactory(), sink); + try { + consumer.accept(outputStream); + } catch (Throwable t) { + sink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); + } + outputStream.close(); + + return mono; + }; + } + + /** + * Variant of {@link BodyInserters#fromInputStream(Supplier, int)} that uses + * a default chunk size. + * @param streamSupplier the supplier that is supplying the stream to write + * @return the inserter to write the request body from a supplied input stream + */ + public static BodyInserter fromInputStream( + Supplier streamSupplier) { + return fromInputStream(streamSupplier, 8192); + } + + /** + * Inserter where the request body is read from an input stream. The supplied + * input stream is closed once it is no longer consumed. + * @param streamSupplier the supplier that is supplying the stream to write + * @param chunkSize the size of each chunk that is buffered before sending + * @return the inserter to write the request body from a supplied input stream + */ + public static BodyInserter fromInputStream( + Supplier streamSupplier, int chunkSize) { + + Assert.notNull(streamSupplier, "'streamSupplier' must not be null"); + Assert.state(chunkSize > 0, "'chunkSize' must be a positive number"); + return (outputMessage, context) -> { + Sinks.Many sink = Sinks.many() + .unicast() + .onBackpressureBuffer(); + + DataBufferFactory factory = outputMessage.bufferFactory(); + Mono mono = outputMessage.writeWith(sink.asFlux()); + try { + InputStream inputStream = streamSupplier.get(); + if (inputStream == null) { + sink.emitError(new NullPointerException("inputStream"), Sinks.EmitFailureHandler.FAIL_FAST); + } else { + try (inputStream) { + int length; + byte[] buffer = new byte[chunkSize]; + while ((length = inputStream.read(buffer)) != -1) { + if (length == 0) { + continue; + } + byte[] wrapped = new byte[length]; + System.arraycopy(buffer, 0, wrapped, 0, length); + sink.emitNext(factory.wrap(wrapped), Sinks.EmitFailureHandler.FAIL_FAST); + } + sink.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST); + } + } + } catch (Throwable t) { + sink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); + } + + return mono; + }; + } private static Mono writeWithMessageWriters( M outputMessage, BodyInserter.Context context, Object body, ResolvableType bodyType, @Nullable ReactiveAdapter adapter) { @@ -477,6 +568,20 @@ > MultipartInserter withPublisher(String name, P publi } + /** + * A consumer for an output stream of which the content is written to the request body. + */ + @FunctionalInterface + public interface FromOutputStream { + + /** + * Accepts an output stream which content is written to the request body. + * @param outputStream the output stream that represents the request body + * @throws IOException if an I/O error occurs what aborts the request + */ + void accept(OutputStream outputStream) throws IOException; + + } private static class DefaultFormInserter implements FormInserter { @@ -556,4 +661,45 @@ public Mono insert(ClientHttpRequest outputMessage, Context context) { } } + private static class WriterOutputStream extends OutputStream { + + private final DataBufferFactory factory; + + private final Sinks.Many sink; + + private final AtomicBoolean closed = new AtomicBoolean(); + + private WriterOutputStream(DataBufferFactory factory, Sinks.Many sink) { + this.factory = factory; + this.sink = sink; + } + + @Override + public void write(int b) throws IOException { + if (closed.get()) { + throw new IOException("closed"); + } + DataBuffer buffer = factory.allocateBuffer(1); + buffer.write((byte) (b & 0xFF)); + sink.tryEmitNext(buffer).orThrow(); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + if (closed.get()) { + throw new IOException("closed"); + } + DataBuffer buffer = factory.allocateBuffer(len); + buffer.write(b, off, len); + sink.tryEmitNext(buffer).orThrow(); + } + + @Override + public void close() { + if (closed.compareAndSet(false, true)) { + sink.tryEmitComplete().orThrow(); + } + } + } + } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequest.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequest.java new file mode 100644 index 000000000000..56db0ac35b87 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.client.AbstractBufferingClientHttpRequest; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.net.URI; + +/** + * {@link ClientHttpRequest} implementation based on + * Spring's {@link WebClient}. + * + *

Created via the {@link WebClientHttpRequestFactory}. + */ +final class WebClientHttpRequest extends AbstractBufferingClientHttpRequest { + + private final WebClient webClient; + + private final HttpMethod method; + + private final URI uri; + + WebClientHttpRequest(WebClient client, HttpMethod method, URI uri) { + this.webClient = client; + this.method = method; + this.uri = uri; + } + + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + @Deprecated + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException { + WebClient.RequestHeadersSpec request = this.webClient.method(this.method) + .uri(this.uri) + .bodyValue(bufferedOutput.length == 0 + ? BodyInserters.empty() + : BodyInserters.fromValue(bufferedOutput)); + + request.headers(value -> value.addAll(headers)); + + @SuppressWarnings("deprecation") + ClientResponse response = request.exchange().block(); + return new WebClientHttpResponse(response); + } +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequestFactory.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequestFactory.java new file mode 100644 index 000000000000..1597cec8c2ca --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequestFactory.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import org.springframework.http.HttpMethod; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.util.Assert; +import org.springframework.web.reactive.function.client.WebClient; + +import java.io.IOException; +import java.net.URI; + +/** + * {@link ClientHttpRequestFactory} implementation that + * uses Spring's {@link WebClient}. + */ +public class WebClientHttpRequestFactory implements ClientHttpRequestFactory { + + private WebClient webClient; + + private boolean bufferRequestBody = true; + + + /** + * Create a new instance of the {@code WebClientHttpRequestFactory} + * with a default {@link WebClient} based on system properties. + */ + public WebClientHttpRequestFactory() { + this.webClient = WebClient.create(); + } + + /** + * Create a new instance of the {@code WebClientHttpRequestFactory} + * with the given {@link WebClient} instance. + * @param webClient the HttpClient instance to use for this request factory + */ + public WebClientHttpRequestFactory(WebClient webClient) { + this.webClient = webClient; + } + + /** + * Set the {@code HttpClient} used for + * {@linkplain #createRequest(URI, HttpMethod) synchronous execution}. + */ + public void setHttpClient(WebClient webClient) { + Assert.notNull(webClient, "WebClient must not be null"); + this.webClient = webClient; + } + + /** + * Return the {@code HttpClient} used for + * {@linkplain #createRequest(URI, HttpMethod) synchronous execution}. + */ + public WebClient getHttpClient() { + return this.webClient; + } + + /** + * Indicates whether this request factory should buffer the request body internally. + *

Default is {@code true}. When sending large amounts of data via POST or PUT, it is + * recommended to change this property to {@code false}, so as not to run out of memory. + * @since 4.0 + */ + public void setBufferRequestBody(boolean bufferRequestBody) { + this.bufferRequestBody = bufferRequestBody; + } + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + WebClient client = getHttpClient(); + + if (this.bufferRequestBody) { + return new WebClientHttpRequest(client, httpMethod, uri); + } + else { + return new WebClientStreamingHttpRequest(client, httpMethod, uri); + } + } +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpResponse.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpResponse.java new file mode 100644 index 000000000000..eaa9a1247804 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpResponse.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.web.reactive.function.client.ClientResponse; + +import java.io.IOException; +import java.io.InputStream; +/** + * {@link ClientHttpResponse} implementation based on + * Spring's web client. + * + *

Created via the {@link WebClientHttpRequest}. + */ +final class WebClientHttpResponse implements ClientHttpResponse { + + private final ClientResponse response; + + @Nullable + private HttpHeaders headers; + + + WebClientHttpResponse(ClientResponse response) { + this.response = response; + } + + + @Override + public HttpStatusCode getStatusCode() throws IOException { + return this.response.statusCode(); + } + + @Override + @Deprecated + public int getRawStatusCode() throws IOException { + return this.response.statusCode().value(); + } + + @Override + public String getStatusText() throws IOException { + return ""; + } + + @Override + public HttpHeaders getHeaders() { + if (this.headers == null) { + this.headers = this.response.headers().asHttpHeaders(); + } + return this.headers; + } + + @Override + public InputStream getBody() throws IOException { + return this.response.body(BodyExtractors.toInputStream()).block(); + } + + @Override + public void close() { + try { + this.response.releaseBody().block(); + } + catch (Exception ex) { + // Ignore exception on close... + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientStreamingHttpRequest.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientStreamingHttpRequest.java new file mode 100644 index 000000000000..9b68ecb46b75 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientStreamingHttpRequest.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.http.client.AbstractClientHttpRequest; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.URI; + +/** + * {@link ClientHttpRequest} implementation based on + * Spring's web client in streaming mode. + * + *

Created via the {@link WebClientHttpRequestFactory}. + */ +final class WebClientStreamingHttpRequest extends AbstractClientHttpRequest + implements StreamingHttpOutputMessage { + + private final WebClient webClient; + + private final HttpMethod method; + + private final URI uri; + + @Nullable + private Body body; + + WebClientStreamingHttpRequest(WebClient client, HttpMethod method, URI uri) { + this.webClient = client; + this.method = method; + this.uri = uri; + } + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + @Deprecated + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public void setBody(Body body) { + assertNotExecuted(); + this.body = body; + } + + @Override + protected OutputStream getBodyInternal(HttpHeaders headers) throws IOException { + throw new UnsupportedOperationException("getBody not supported"); + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers) throws IOException { + WebClient.RequestHeadersSpec request = this.webClient.method(this.method) + .uri(this.uri) + .bodyValue(this.body == null + ? BodyInserters.empty() + : BodyInserters.fromOutputStream(outputStream -> this.body.writeTo(outputStream))); + + request.headers(value -> value.addAll(headers)); + + @SuppressWarnings("deprecation") + ClientResponse response = request.exchange().block(); + return new WebClientHttpResponse(response); + } +} diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java index eea4ef3a8ce7..a460dc94aa74 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java @@ -16,6 +16,8 @@ package org.springframework.web.reactive.function; +import java.io.ByteArrayOutputStream; +import java.io.OutputStream; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -31,6 +33,7 @@ import io.netty.util.IllegalReferenceCountException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.web.reactive.function.client.WebClient; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -409,6 +412,44 @@ public void toDataBuffers() { .verify(); } + @Test + void toMonoInputStream() { + BodyExtractor, ReactiveHttpInputMessage> extractor = BodyExtractors.toMono( + stream -> new String(stream.readAllBytes(), StandardCharsets.UTF_8)); + + byte[] bytes = "foo".getBytes(StandardCharsets.UTF_8); + DefaultDataBuffer dataBuffer = DefaultDataBufferFactory.sharedInstance.wrap(ByteBuffer.wrap(bytes)); + Flux body = Flux.just(dataBuffer); + + MockServerHttpRequest request = MockServerHttpRequest.post("/").body(body); + Mono result = extractor.extract(request, this.context); + + StepVerifier.create(result) + .expectNext("foo") + .expectComplete() + .verify(); + } + + @Test + void toMonoOutputStream() { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + BodyExtractor, ReactiveHttpInputMessage> extractor = BodyExtractors.toMono( + () -> outputStream); + + byte[] bytes = "foo".getBytes(StandardCharsets.UTF_8); + DefaultDataBuffer dataBuffer = DefaultDataBufferFactory.sharedInstance.wrap(ByteBuffer.wrap(bytes)); + Flux body = Flux.just(dataBuffer); + + MockServerHttpRequest request = MockServerHttpRequest.post("/").body(body); + Mono result = extractor.extract(request, this.context); + + StepVerifier.create(result) + .expectComplete() + .verify(); + + assertThat(outputStream.toString(StandardCharsets.UTF_8)).isEqualTo("foo"); + } + @Test // SPR-17054 public void unsupportedMediaTypeShouldConsumeAndCancel() { NettyDataBufferFactory factory = new NettyDataBufferFactory(new PooledByteBufAllocator(true)); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyInsertersTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyInsertersTests.java index 71124300d96e..a137ec560180 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyInsertersTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyInsertersTests.java @@ -16,6 +16,7 @@ package org.springframework.web.reactive.function; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.net.URI; import java.nio.ByteBuffer; @@ -437,6 +438,49 @@ public void ofDataBuffers() { .verify(); } + @Test + void fromOutputStream() { + byte[] bytes = "foo".getBytes(UTF_8); + + BodyInserter inserter = BodyInserters.fromOutputStream( + outputStream -> outputStream.write(bytes)); + + MockServerHttpResponse response = new MockServerHttpResponse(); + Mono result = inserter.insert(response, this.context); + StepVerifier.create(result).expectComplete().verify(); + + StepVerifier.create(response.getBody()) + .consumeNextWith(dataBuffer -> { + byte[] resultBytes = new byte[dataBuffer.readableByteCount()]; + dataBuffer.read(resultBytes); + DataBufferUtils.release(dataBuffer); + assertThat(resultBytes).isEqualTo(bytes); + }) + .expectComplete() + .verify(); + } + + @Test + void fromInputStream() { + byte[] bytes = "foo".getBytes(UTF_8); + + BodyInserter inserter = BodyInserters.fromInputStream( + () -> new ByteArrayInputStream(bytes)); + + MockServerHttpResponse response = new MockServerHttpResponse(); + Mono result = inserter.insert(response, this.context); + StepVerifier.create(result).expectComplete().verify(); + + StepVerifier.create(response.getBody()) + .consumeNextWith(dataBuffer -> { + byte[] resultBytes = new byte[dataBuffer.readableByteCount()]; + dataBuffer.read(resultBytes); + DataBufferUtils.release(dataBuffer); + assertThat(resultBytes).isEqualTo(bytes); + }) + .expectComplete() + .verify(); + } interface SafeToSerialize {}