Skip to content

Commit 35b99e4

Browse files
authored
Expose HttpResponse PipeWriter in Kestrel (#7110)
1 parent 7b3149a commit 35b99e4

File tree

55 files changed

+3291
-522
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+3291
-522
lines changed

src/Http/Http.Abstractions/src/Extensions/HttpResponseWritingExtensions.cs

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System;
5+
using System.IO.Pipelines;
56
using System.Text;
67
using System.Threading;
78
using System.Threading.Tasks;
@@ -60,8 +61,73 @@ public static class HttpResponseWritingExtensions
6061
throw new ArgumentNullException(nameof(encoding));
6162
}
6263

63-
byte[] data = encoding.GetBytes(text);
64-
return response.Body.WriteAsync(data, 0, data.Length, cancellationToken);
64+
// Need to call StartAsync before GetMemory/GetSpan
65+
if (!response.HasStarted)
66+
{
67+
var startAsyncTask = response.StartAsync(cancellationToken);
68+
if (!startAsyncTask.IsCompletedSuccessfully)
69+
{
70+
return StartAndWriteAsyncAwaited(response, text, encoding, cancellationToken, startAsyncTask);
71+
}
72+
}
73+
74+
Write(response, text, encoding);
75+
76+
var flushAsyncTask = response.BodyPipe.FlushAsync(cancellationToken);
77+
if (flushAsyncTask.IsCompletedSuccessfully)
78+
{
79+
// Most implementations of ValueTask reset state in GetResult, so call it before returning a completed task.
80+
flushAsyncTask.GetAwaiter().GetResult();
81+
return Task.CompletedTask;
82+
}
83+
84+
return flushAsyncTask.AsTask();
85+
}
86+
87+
private static async Task StartAndWriteAsyncAwaited(this HttpResponse response, string text, Encoding encoding, CancellationToken cancellationToken, Task startAsyncTask)
88+
{
89+
await startAsyncTask;
90+
Write(response, text, encoding);
91+
await response.BodyPipe.FlushAsync(cancellationToken);
92+
}
93+
94+
private static void Write(this HttpResponse response, string text, Encoding encoding)
95+
{
96+
var pipeWriter = response.BodyPipe;
97+
var encodedLength = encoding.GetByteCount(text);
98+
var destination = pipeWriter.GetSpan(encodedLength);
99+
100+
if (encodedLength <= destination.Length)
101+
{
102+
// Just call Encoding.GetBytes if everything will fit into a single segment.
103+
var bytesWritten = encoding.GetBytes(text, destination);
104+
pipeWriter.Advance(bytesWritten);
105+
}
106+
else
107+
{
108+
WriteMutliSegmentEncoded(pipeWriter, text, encoding, destination, encodedLength);
109+
}
110+
}
111+
112+
private static void WriteMutliSegmentEncoded(PipeWriter writer, string text, Encoding encoding, Span<byte> destination, int encodedLength)
113+
{
114+
var encoder = encoding.GetEncoder();
115+
var source = text.AsSpan();
116+
var completed = false;
117+
var totalBytesUsed = 0;
118+
119+
// This may be a bug, but encoder.Convert returns completed = true for UTF7 too early.
120+
// Therefore, we check encodedLength - totalBytesUsed too.
121+
while (!completed || encodedLength - totalBytesUsed != 0)
122+
{
123+
encoder.Convert(source, destination, flush: source.Length == 0, out var charsUsed, out var bytesUsed, out completed);
124+
totalBytesUsed += bytesUsed;
125+
126+
writer.Advance(bytesUsed);
127+
source = source.Slice(charsUsed);
128+
129+
destination = writer.GetSpan(encodedLength - totalBytesUsed);
130+
}
65131
}
66132
}
67-
}
133+
}

src/Http/Http.Abstractions/test/HttpResponseWritingExtensionsTests.cs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
// Copyright (c) .NET Foundation. All rights reserved.
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

4+
using System;
45
using System.IO;
6+
using System.IO.Pipelines;
7+
using System.IO.Pipelines.Tests;
8+
using System.Text;
59
using System.Threading.Tasks;
610
using Xunit;
711

@@ -28,6 +32,65 @@ public async Task WritingText_MultipleWrites()
2832
Assert.Equal(22, context.Response.Body.Length);
2933
}
3034

35+
[Theory]
36+
[MemberData(nameof(Encodings))]
37+
public async Task WritingTextThatRequiresMultipleSegmentsWorks(Encoding encoding)
38+
{
39+
// Need to change the StreamPipeWriter with a capped MemoryPool
40+
var memoryPool = new TestMemoryPool(maxBufferSize: 16);
41+
var outputStream = new MemoryStream();
42+
var streamPipeWriter = new StreamPipeWriter(outputStream, minimumSegmentSize: 0, memoryPool);
43+
44+
HttpContext context = new DefaultHttpContext();
45+
context.Response.BodyPipe = streamPipeWriter;
46+
47+
var inputString = "昨日すき焼きを食べました";
48+
var expected = encoding.GetBytes(inputString);
49+
await context.Response.WriteAsync(inputString, encoding);
50+
51+
outputStream.Position = 0;
52+
var actual = new byte[expected.Length];
53+
var length = outputStream.Read(actual);
54+
55+
var res1 = encoding.GetString(actual);
56+
var res2 = encoding.GetString(expected);
57+
Assert.Equal(expected.Length, length);
58+
Assert.Equal(expected, actual);
59+
streamPipeWriter.Complete();
60+
}
61+
62+
[Theory]
63+
[MemberData(nameof(Encodings))]
64+
public async Task WritingTextWithPassedInEncodingWorks(Encoding encoding)
65+
{
66+
HttpContext context = CreateRequest();
67+
68+
var inputString = "昨日すき焼きを食べました";
69+
var expected = encoding.GetBytes(inputString);
70+
await context.Response.WriteAsync(inputString, encoding);
71+
72+
context.Response.Body.Position = 0;
73+
var actual = new byte[expected.Length * 2];
74+
var length = context.Response.Body.Read(actual);
75+
76+
var actualShortened = new byte[length];
77+
Array.Copy(actual, actualShortened, length);
78+
79+
Assert.Equal(expected.Length, length);
80+
Assert.Equal(expected, actualShortened);
81+
}
82+
83+
public static TheoryData<Encoding> Encodings =>
84+
new TheoryData<Encoding>
85+
{
86+
{ Encoding.ASCII },
87+
{ Encoding.BigEndianUnicode },
88+
{ Encoding.Unicode },
89+
{ Encoding.UTF32 },
90+
{ Encoding.UTF7 },
91+
{ Encoding.UTF8 }
92+
};
93+
3194
private HttpContext CreateRequest()
3295
{
3396
HttpContext context = new DefaultHttpContext();

src/Http/Http.Abstractions/test/Microsoft.AspNetCore.Http.Abstractions.Tests.csproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
<TargetFramework>netcoreapp3.0</TargetFramework>
55
</PropertyGroup>
66

7+
<ItemGroup>
8+
<ProjectReference Include="..\..\Http\test\Microsoft.AspNetCore.Http.Tests.csproj" />
9+
</ItemGroup>
10+
711
<ItemGroup>
812
<Reference Include="Microsoft.AspNetCore.Http" />
913
</ItemGroup>

src/Http/Http.Features/src/FeatureReferences.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,4 @@ private TFeature UpdateCached<TFeature, TState>(ref TFeature cached, TState stat
9595
public TFeature Fetch<TFeature>(ref TFeature cached, Func<IFeatureCollection, TFeature> factory)
9696
where TFeature : class => Fetch(ref cached, Collection, factory);
9797
}
98-
}
98+
}

src/Http/Http/src/Internal/DefaultHttpResponse.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ public override Task StartAsync(CancellationToken cancellationToken = default)
151151
return HttpResponseFeature.Body.FlushAsync(cancellationToken);
152152
}
153153

154-
return HttpResponseStartFeature.StartAsync();
154+
return HttpResponseStartFeature.StartAsync(cancellationToken);
155155
}
156156

157157
struct FeatureInterfaces

src/Http/Http/src/ReadOnlyPipeStream.cs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ namespace System.IO.Pipelines
1313
/// </summary>
1414
public class ReadOnlyPipeStream : Stream
1515
{
16-
private readonly PipeReader _pipeReader;
1716
private bool _allowSynchronousIO = true;
1817

1918
/// <summary>
@@ -33,7 +32,7 @@ public ReadOnlyPipeStream(PipeReader pipeReader) :
3332
public ReadOnlyPipeStream(PipeReader pipeReader, bool allowSynchronousIO)
3433
{
3534
_allowSynchronousIO = allowSynchronousIO;
36-
_pipeReader = pipeReader;
35+
InnerPipeReader = pipeReader;
3736
}
3837

3938
/// <inheritdoc />
@@ -62,6 +61,8 @@ public override int WriteTimeout
6261
set => throw new NotSupportedException();
6362
}
6463

64+
public PipeReader InnerPipeReader { get; }
65+
6566
/// <inheritdoc />
6667
public override void Write(byte[] buffer, int offset, int count)
6768
=> throw new NotSupportedException();
@@ -160,7 +161,7 @@ private async ValueTask<int> ReadAsyncInternal(Memory<byte> buffer, Cancellation
160161
{
161162
while (true)
162163
{
163-
var result = await _pipeReader.ReadAsync(cancellationToken);
164+
var result = await InnerPipeReader.ReadAsync(cancellationToken);
164165
var readableBuffer = result.Buffer;
165166
var readableBufferLength = readableBuffer.Length;
166167

@@ -186,7 +187,7 @@ private async ValueTask<int> ReadAsyncInternal(Memory<byte> buffer, Cancellation
186187
}
187188
finally
188189
{
189-
_pipeReader.AdvanceTo(consumed);
190+
InnerPipeReader.AdvanceTo(consumed);
190191
}
191192
}
192193
}
@@ -211,7 +212,7 @@ private async Task CopyToAsyncInternal(Stream destination, CancellationToken can
211212
{
212213
while (true)
213214
{
214-
var result = await _pipeReader.ReadAsync(cancellationToken);
215+
var result = await InnerPipeReader.ReadAsync(cancellationToken);
215216
var readableBuffer = result.Buffer;
216217
var readableBufferLength = readableBuffer.Length;
217218

@@ -232,7 +233,7 @@ private async Task CopyToAsyncInternal(Stream destination, CancellationToken can
232233
}
233234
finally
234235
{
235-
_pipeReader.AdvanceTo(readableBuffer.End);
236+
InnerPipeReader.AdvanceTo(readableBuffer.End);
236237
}
237238
}
238239
}

src/Http/Http/src/StreamPipeReader.cs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ public class StreamPipeReader : PipeReader, IDisposable
1717
{
1818
private readonly int _minimumSegmentSize;
1919
private readonly int _minimumReadThreshold;
20-
private readonly Stream _readingStream;
2120
private readonly MemoryPool<byte> _pool;
2221

2322
private CancellationTokenSource _internalTokenSource;
@@ -42,15 +41,14 @@ public StreamPipeReader(Stream readingStream)
4241
{
4342
}
4443

45-
4644
/// <summary>
4745
/// Creates a new StreamPipeReader.
4846
/// </summary>
4947
/// <param name="readingStream">The stream to read from.</param>
5048
/// <param name="options">The options to use.</param>
5149
public StreamPipeReader(Stream readingStream, StreamPipeReaderOptions options)
5250
{
53-
_readingStream = readingStream ?? throw new ArgumentNullException(nameof(readingStream));
51+
InnerStream = readingStream ?? throw new ArgumentNullException(nameof(readingStream));
5452

5553
if (options == null)
5654
{
@@ -70,7 +68,7 @@ public StreamPipeReader(Stream readingStream, StreamPipeReaderOptions options)
7068
/// <summary>
7169
/// Gets the inner stream that is being read from.
7270
/// </summary>
73-
public Stream InnerStream => _readingStream;
71+
public Stream InnerStream { get; }
7472

7573
/// <inheritdoc />
7674
public override void AdvanceTo(SequencePosition consumed)
@@ -235,7 +233,7 @@ public override async ValueTask<ReadResult> ReadAsync(CancellationToken cancella
235233
{
236234
AllocateReadTail();
237235
#if NETCOREAPP3_0
238-
var length = await _readingStream.ReadAsync(_readTail.AvailableMemory.Slice(_readTail.End), tokenSource.Token);
236+
var length = await InnerStream.ReadAsync(_readTail.AvailableMemory.Slice(_readTail.End), tokenSource.Token);
239237
#elif NETSTANDARD2_0
240238
if (!MemoryMarshal.TryGetArray<byte>(_readTail.AvailableMemory.Slice(_readTail.End), out var arraySegment))
241239
{

src/Http/Http/src/StreamPipeWriter.cs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ namespace System.IO.Pipelines
1414
public class StreamPipeWriter : PipeWriter, IDisposable
1515
{
1616
private readonly int _minimumSegmentSize;
17-
private readonly Stream _writingStream;
1817
private int _bytesWritten;
1918

2019
private List<CompletedBuffer> _completedSegments;
@@ -53,14 +52,14 @@ public StreamPipeWriter(Stream writingStream) : this(writingStream, 4096)
5352
public StreamPipeWriter(Stream writingStream, int minimumSegmentSize, MemoryPool<byte> pool = null)
5453
{
5554
_minimumSegmentSize = minimumSegmentSize;
56-
_writingStream = writingStream;
55+
InnerStream = writingStream;
5756
_pool = pool ?? MemoryPool<byte>.Shared;
5857
}
5958

6059
/// <summary>
6160
/// Gets the inner stream that is being written to.
6261
/// </summary>
63-
public Stream InnerStream => _writingStream;
62+
public Stream InnerStream { get; }
6463

6564
/// <inheritdoc />
6665
public override void Advance(int count)
@@ -180,10 +179,10 @@ private async ValueTask<FlushResult> FlushAsyncInternal(CancellationToken cancel
180179
{
181180
var segment = _completedSegments[0];
182181
#if NETCOREAPP3_0
183-
await _writingStream.WriteAsync(segment.Buffer.Slice(0, segment.Length), localToken);
182+
await InnerStream.WriteAsync(segment.Buffer.Slice(0, segment.Length), localToken);
184183
#elif NETSTANDARD2_0
185184
MemoryMarshal.TryGetArray<byte>(segment.Buffer, out var arraySegment);
186-
await _writingStream.WriteAsync(arraySegment.Array, 0, segment.Length, localToken);
185+
await InnerStream.WriteAsync(arraySegment.Array, 0, segment.Length, localToken);
187186
#else
188187
#error Target frameworks need to be updated.
189188
#endif
@@ -196,18 +195,18 @@ private async ValueTask<FlushResult> FlushAsyncInternal(CancellationToken cancel
196195
if (!_currentSegment.IsEmpty)
197196
{
198197
#if NETCOREAPP3_0
199-
await _writingStream.WriteAsync(_currentSegment.Slice(0, _position), localToken);
198+
await InnerStream.WriteAsync(_currentSegment.Slice(0, _position), localToken);
200199
#elif NETSTANDARD2_0
201200
MemoryMarshal.TryGetArray<byte>(_currentSegment, out var arraySegment);
202-
await _writingStream.WriteAsync(arraySegment.Array, 0, _position, localToken);
201+
await InnerStream.WriteAsync(arraySegment.Array, 0, _position, localToken);
203202
#else
204203
#error Target frameworks need to be updated.
205204
#endif
206205
_bytesWritten -= _position;
207206
_position = 0;
208207
}
209208

210-
await _writingStream.FlushAsync(localToken);
209+
await InnerStream.FlushAsync(localToken);
211210

212211
return new FlushResult(isCanceled: false, _isCompleted);
213212
}

0 commit comments

Comments
 (0)