Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 99 additions & 71 deletions src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,41 +38,12 @@ public static void EfficientCopyTo(this Stream input, Stream output)

public static int Read(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
{
try
{
using var manualResetEvent = new ManualResetEventSlim();
var readOperation = stream.BeginRead(
buffer,
offset,
count,
state => ((ManualResetEventSlim)state.AsyncState).Set(),
manualResetEvent);

if (readOperation.IsCompleted || manualResetEvent.Wait(timeout, cancellationToken))
{
return stream.EndRead(readOperation);
}
}
catch (OperationCanceledException)
{
// Have to suppress OperationCanceledException here, it will be thrown after the stream will be disposed.
}
catch (ObjectDisposedException)
{
throw new IOException();
}

try
{
stream.Dispose();
}
catch
{
// Ignore any exceptions
}

cancellationToken.ThrowIfCancellationRequested();
throw new TimeoutException();
return UseStreamWithTimeout(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few naming suggestions:
UseStreamWithTimeout -> rename to ExecuteOperationWithTimeout?
method -> rename to operation?

stream,
(str, state) => str.Read(state.buffer, state.offset, state.count),
(buffer, offset, count),
timeout,
cancellationToken);
}

public static async Task<int> ReadAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
Expand Down Expand Up @@ -219,43 +190,16 @@ public static async Task ReadBytesAsync(this Stream stream, byte[] destination,

public static void Write(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
{
try
{
using var manualResetEvent = new ManualResetEventSlim();
var writeOperation = stream.BeginWrite(
buffer,
offset,
count,
state => ((ManualResetEventSlim)state.AsyncState).Set(),
manualResetEvent);

if (writeOperation.IsCompleted || manualResetEvent.Wait(timeout, cancellationToken))
UseStreamWithTimeout(
stream,
(str, state) =>
{
stream.EndWrite(writeOperation);
return;
}
}
catch (OperationCanceledException)
{
// Have to suppress OperationCanceledException here, it will be thrown after the stream will be disposed.
}
catch (ObjectDisposedException)
{
// It's possible to get ObjectDisposedException when the connection pool was closed with interruptInUseConnections set to true.
throw new IOException();
}

try
{
stream.Dispose();
}
catch
{
// Ignore any exceptions
}

cancellationToken.ThrowIfCancellationRequested();
throw new TimeoutException();
str.Write(state.buffer, state.offset, state.count);
return true;
},
(buffer, offset, count),
timeout,
cancellationToken);
}

public static async Task WriteAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
Expand Down Expand Up @@ -325,5 +269,89 @@ public static async Task WriteBytesAsync(this Stream stream, OperationContext op
count -= bytesToWrite;
}
}

private static TResult UseStreamWithTimeout<TResult, TState>(Stream stream, Func<Stream, TState, TResult> method, TState state, TimeSpan timeout, CancellationToken cancellationToken)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just
byte[] buffer, int offset, int count instead of state, as it's the same in both cases?

{
StreamDisposeCallbackState callbackState = null;
Timer timer = null;
CancellationTokenRegistration cancellationSubscription = default;
if (timeout != Timeout.InfiniteTimeSpan)
{
callbackState = new StreamDisposeCallbackState(stream);
timer = new Timer(DisposeStreamCallback, callbackState, timeout, Timeout.InfiniteTimeSpan);
}

if (cancellationToken.CanBeCanceled)
{
callbackState ??= new StreamDisposeCallbackState(stream);
cancellationSubscription = cancellationToken.Register(DisposeStreamCallback, callbackState);
}

try
{
var result = method(stream, state);
if (callbackState?.TryChangeState(OperationState.Done) == false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe TryChangeStateFromInProgress for more clarity?

{
// if cannot change the state - then the stream was/will be disposed, throw here
throw new IOException();
}

return result;
}
catch (IOException)
{
if (callbackState?.OperationState == OperationState.Cancelled)
{
cancellationToken.ThrowIfCancellationRequested();
throw new TimeoutException();
}

throw;
}
finally
{
timer?.Dispose();
cancellationSubscription.Dispose();
}

static void DisposeStreamCallback(object state)
{
var disposeCallbackState = (StreamDisposeCallbackState)state;
if (!disposeCallbackState.TryChangeState(OperationState.Cancelled))
{
// if cannot change the state - then I/O was already succeeded
return;
}

try
{
disposeCallbackState.Stream.Dispose();
}
catch (Exception)
{
// callbacks should not fail, suppress any exceptions here
}
}
}

private record StreamDisposeCallbackState(Stream Stream)
{
private int _operationState = 0;

public OperationState OperationState
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public OperationState OperationState => (OperationState)_operationState;

{
get => (OperationState)_operationState;
}

public bool TryChangeState(OperationState newState) =>
Interlocked.CompareExchange(ref _operationState, (int)newState, (int)OperationState.InProgress) == (int)OperationState.InProgress;
}

private enum OperationState
{
InProgress = 0,
Done,
Cancelled,
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -811,19 +811,8 @@ public async Task SendMessage_should_put_the_message_on_the_stream_and_raise_the

private void SetupStreamRead(Mock<Stream> streamMock, TaskCompletionSource<int> tcs)
{
streamMock.Setup(s => s.BeginRead(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
.Returns((byte[] _, int __, int ___, AsyncCallback callback, object state) =>
{
var innerTcs = new TaskCompletionSource<int>(state);
tcs.Task.ContinueWith(t =>
{
innerTcs.TrySetException(t.Exception.InnerException);
callback(innerTcs.Task);
});
return innerTcs.Task;
});
streamMock.Setup(s => s.EndRead(It.IsAny<IAsyncResult>()))
.Returns<IAsyncResult>(x => ((Task<int>)x).GetAwaiter().GetResult());
streamMock.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
.Returns((byte[] _, int __, int ___) => tcs.Task.GetAwaiter().GetResult());
streamMock.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
.Returns(tcs.Task);
streamMock.Setup(s => s.Close()).Callback(() => tcs.TrySetException(new ObjectDisposedException("stream")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,18 @@ public async Task ReadBytes_with_byte_array_should_have_expected_effect_for_part
var bytes = new byte[] { 1, 2, 3 };
var n = 0;
var position = 0;
Task<int> ReadPartial (byte[] buffer, int offset, int count)
int ReadPartial (byte[] buffer, int offset, int count)
{
var length = partition[n++];
Buffer.BlockCopy(bytes, position, buffer, offset, length);
position += length;
return Task.FromResult(length);
return length;
}

mockStream.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count));
mockStream.Setup(s => s.BeginRead(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
.Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count));
mockStream.Setup(s => s.EndRead(It.IsAny<IAsyncResult>()))
.Returns<IAsyncResult>(x => ((Task<int>)x).GetAwaiter().GetResult());
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(ReadPartial(buffer, offset, count)));
mockStream.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
.Returns((byte[] buffer, int offset, int count) => ReadPartial(buffer, offset, count));
var destination = new byte[3];

if (async)
Expand Down Expand Up @@ -267,20 +265,18 @@ public async Task ReadBytes_with_byte_buffer_should_have_expected_effect_for_par
var destination = new ByteArrayBuffer(new byte[3], 3);
var n = 0;
var position = 0;
Task<int> ReadPartial (byte[] buffer, int offset, int count)
int ReadPartial (byte[] buffer, int offset, int count)
{
var length = partition[n++];
Buffer.BlockCopy(bytes, position, buffer, offset, length);
position += length;
return Task.FromResult(length);
return length;
}

mockStream.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count));
mockStream.Setup(s => s.BeginRead(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
.Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count));
mockStream.Setup(s => s.EndRead(It.IsAny<IAsyncResult>()))
.Returns<IAsyncResult>(x => ((Task<int>)x).GetAwaiter().GetResult());
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(ReadPartial(buffer, offset, count)));
mockStream.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
.Returns((byte[] buffer, int offset, int count) => ReadPartial(buffer, offset, count));

if (async)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public void Heartbeat_should_be_emitted_before_connection_open()

var mockStream = new Mock<Stream>();
mockStream
.Setup(s => s.BeginWrite(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
.Setup(s => s.Write(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
.Callback(() => EnqueueEvent(HelloReceivedEvent))
.Throws(new Exception("Stream is closed."));

Expand Down
Loading