diff --git a/eng/Baseline.Designer.props b/eng/Baseline.Designer.props index 29e2efcf98b8..4ca09124b681 100644 --- a/eng/Baseline.Designer.props +++ b/eng/Baseline.Designer.props @@ -463,6 +463,7 @@ + @@ -472,6 +473,7 @@ + diff --git a/eng/PatchConfig.props b/eng/PatchConfig.props index 513d6f73b15c..48bab5aea131 100644 --- a/eng/PatchConfig.props +++ b/eng/PatchConfig.props @@ -46,6 +46,8 @@ Later on, this will be checked using this condition: + Microsoft.AspNetCore.Http.Connections; + Microsoft.AspNetCore.SignalR.Core; diff --git a/src/SignalR/clients/ts/FunctionalTests/selenium/run-tests.ts b/src/SignalR/clients/ts/FunctionalTests/selenium/run-tests.ts index c74603b12c56..6f548f153484 100644 --- a/src/SignalR/clients/ts/FunctionalTests/selenium/run-tests.ts +++ b/src/SignalR/clients/ts/FunctionalTests/selenium/run-tests.ts @@ -1,7 +1,8 @@ import { ChildProcess, spawn } from "child_process"; -import * as fs from "fs"; +import * as _fs from "fs"; import { EOL } from "os"; import * as path from "path"; +import { promisify } from "util"; import { PassThrough, Readable } from "stream"; import { run } from "../../webdriver-tap-runner/lib"; @@ -9,6 +10,16 @@ import { run } from "../../webdriver-tap-runner/lib"; import * as _debug from "debug"; const debug = _debug("signalr-functional-tests:run"); +const ARTIFACTS_DIR = path.resolve(__dirname, "..", "..", "..", "..", "artifacts"); +const LOGS_DIR = path.resolve(ARTIFACTS_DIR, "logs"); + +// Promisify things from fs we want to use. +const fs = { + createWriteStream: _fs.createWriteStream, + exists: promisify(_fs.exists), + mkdir: promisify(_fs.mkdir), +}; + process.on("unhandledRejection", (reason) => { console.error(`Unhandled promise rejection: ${reason}`); process.exit(1); @@ -102,6 +113,13 @@ if (chromePath) { try { const serverPath = path.resolve(__dirname, "..", "bin", configuration, "netcoreapp2.1", "FunctionalTests.dll"); + if (!await fs.exists(ARTIFACTS_DIR)) { + await fs.mkdir(ARTIFACTS_DIR); + } + if (!await fs.exists(LOGS_DIR)) { + await fs.mkdir(LOGS_DIR); + } + debug(`Launching Functional Test Server: ${serverPath}`); const dotnet = spawn("dotnet", [serverPath], { env: { @@ -117,6 +135,9 @@ if (chromePath) { } } + const logStream = fs.createWriteStream(path.resolve(LOGS_DIR, "ts.functionaltests.dotnet.log")); + dotnet.stdout.pipe(logStream); + process.on("SIGINT", cleanup); process.on("exit", cleanup); diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index 045e821ee108..cf04d5ff484a 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -27,6 +27,8 @@ public class HttpConnectionContext : ConnectionContext, IHttpTransportFeature, IConnectionInherentKeepAliveFeature { + private static long _tenSeconds = TimeSpan.FromSeconds(10).Ticks; + private readonly object _itemsLock = new object(); private readonly object _heartbeatLock = new object(); private List<(Action handler, object state)> _heartbeatHandlers; @@ -35,6 +37,13 @@ public class HttpConnectionContext : ConnectionContext, private IDuplexPipe _application; private IDictionary _items; + private CancellationTokenSource _sendCts; + private bool _activeSend; + private long _startedSendTime; + private readonly object _sendingLock = new object(); + + internal CancellationToken SendingToken { get; private set; } + // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously // on the same task private readonly TaskCompletionSource _disposeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -274,24 +283,45 @@ private async Task WaitOnTasks(Task applicationTask, Task transportTask, bool cl // Cancel any pending flushes from back pressure Application?.Output.CancelPendingFlush(); - // Shutdown both sides and wait for nothing - Transport?.Output.Complete(applicationTask.Exception?.InnerException); - Application?.Output.Complete(transportTask.Exception?.InnerException); - + // Normally it isn't safe to try and acquire this lock because the Send can hold onto it for a long time if there is backpressure + // It is safe to wait for this lock now because the Send will be in one of 4 states + // 1. In the middle of a write which is in the middle of being canceled by the CancelPendingFlush above, when it throws + // an OperationCanceledException it will complete the PipeWriter which will make any other Send waiting on the lock + // throw an InvalidOperationException if they call Write + // 2. About to write and see that there is a pending cancel from the CancelPendingFlush, go to 1 to see what happens + // 3. Enters the Send and sees the Dispose state from DisposeAndRemoveAsync and releases the lock + // 4. No Send in progress + await WriteLock.WaitAsync(); try { - Log.WaitingForTransportAndApplication(_logger, TransportType); - // A poorly written application *could* in theory get stuck forever and it'll show up as a memory leak - await Task.WhenAll(applicationTask, transportTask); + // Complete the applications read loop + Application?.Output.Complete(transportTask.Exception?.InnerException); } finally { - Log.TransportAndApplicationComplete(_logger, TransportType); - - // Close the reading side after both sides run - Application?.Input.Complete(); - Transport?.Input.Complete(); + WriteLock.Release(); } + + Application?.Input.CancelPendingRead(); + + await transportTask.NoThrow(); + Application?.Input.Complete(); + + Log.WaitingForTransportAndApplication(_logger, TransportType); + + // A poorly written application *could* in theory get stuck forever and it'll show up as a memory leak + // Wait for application so we can complete the writer safely + await applicationTask.NoThrow(); + Log.TransportAndApplicationComplete(_logger, TransportType); + + // Shutdown application side now that it's finished + Transport?.Output.Complete(applicationTask.Exception?.InnerException); + + // Close the reading side after both sides run + Transport?.Input.Complete(); + + // Observe exceptions + await Task.WhenAll(transportTask, applicationTask); } // Notify all waiters that we're done disposing @@ -311,6 +341,43 @@ private async Task WaitOnTasks(Task applicationTask, Task transportTask, bool cl } } + internal void StartSendCancellation() + { + lock (_sendingLock) + { + if (_sendCts == null || _sendCts.IsCancellationRequested) + { + _sendCts = new CancellationTokenSource(); + SendingToken = _sendCts.Token; + } + + _startedSendTime = DateTime.UtcNow.Ticks; + _activeSend = true; + } + } + + internal void TryCancelSend(long currentTicks) + { + lock (_sendingLock) + { + if (_activeSend) + { + if (currentTicks - _startedSendTime > _tenSeconds) + { + _sendCts.Cancel(); + } + } + } + } + + internal void StopSendCancellation() + { + lock (_sendingLock) + { + _activeSend = false; + } + } + private static class Log { private static readonly Action _disposingConnection = diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index 50910bccfec1..0449e193f11d 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -144,7 +144,7 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti connection.SupportedFormats = TransferFormat.Text; // We only need to provide the Input channel since writing to the application is handled through /send. - var sse = new ServerSentEventsTransport(connection.Application.Input, connection.ConnectionId, _loggerFactory); + var sse = new ServerSentEventsTransport(connection.Application.Input, connection.ConnectionId, connection, _loggerFactory); await DoPersistentConnection(connectionDelegate, sse, context, connection); } @@ -264,7 +264,7 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti context.Response.RegisterForDispose(timeoutSource); context.Response.RegisterForDispose(tokenSource); - var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Input, _loggerFactory); + var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Input, _loggerFactory, connection); // Start the transport connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token); @@ -291,7 +291,9 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti connection.Transport.Output.Complete(connection.ApplicationTask.Exception); // Wait for the transport to run - await connection.TransportTask; + // Ignore exceptions, it has been logged if there is one and the application has finished + // So there is no one to give the exception to + await connection.TransportTask.NoThrow(); // If the status code is a 204 it means the connection is done if (context.Response.StatusCode == StatusCodes.Status204NoContent) @@ -307,6 +309,18 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti pollAgain = false; } } + else if (connection.TransportTask.IsFaulted || connection.TransportTask.IsCanceled) + { + // Cancel current request to release any waiting poll and let dispose aquire the lock + currentRequestTcs.TrySetCanceled(); + + // We should be able to safely dispose because there's no more data being written + // We don't need to wait for close here since we've already waited for both sides + await _manager.DisposeAndRemoveAsync(connection, closeGracefully: false); + + // Don't poll again if we've removed the connection completely + pollAgain = false; + } else if (context.Response.StatusCode == StatusCodes.Status204NoContent) { // Don't poll if the transport task was canceled @@ -511,6 +525,14 @@ private async Task ProcessSend(HttpContext context, HttpConnectionDispatcherOpti context.Response.StatusCode = StatusCodes.Status404NotFound; context.Response.ContentType = "text/plain"; + + // There are no writes anymore (since this is the write "loop") + // So it is safe to complete the writer + // We complete the writer here because we already have the WriteLock acquired + // and it's unsafe to complete outside of the lock + // Other code isn't guaranteed to be able to acquire the lock before another write + // even if CancelPendingFlush is called, and the other write could hang if there is backpressure + connection.Application.Output.Complete(); return; } @@ -549,11 +571,8 @@ private async Task ProcessDeleteAsync(HttpContext context) Log.TerminatingConection(_logger); - // Complete the receiving end of the pipe - connection.Application.Output.Complete(); - - // Dispose the connection gracefully, but don't wait for it. We assign it here so we can wait in tests - connection.DisposeAndRemoveTask = _manager.DisposeAndRemoveAsync(connection, closeGracefully: true); + // Dispose the connection, but don't wait for it. We assign it here so we can wait in tests + connection.DisposeAndRemoveTask = _manager.DisposeAndRemoveAsync(connection, closeGracefully: false); context.Response.StatusCode = StatusCodes.Status202Accepted; context.Response.ContentType = "text/plain"; diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index 43e598274844..d6860b7ad196 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -30,6 +30,7 @@ public partial class HttpConnectionManager private readonly TimerAwaitable _nextHeartbeat; private readonly ILogger _logger; private readonly ILogger _connectionLogger; + private readonly bool _useSendTimeout = true; public HttpConnectionManager(ILoggerFactory loggerFactory, IApplicationLifetime appLifetime) { @@ -38,6 +39,11 @@ public HttpConnectionManager(ILoggerFactory loggerFactory, IApplicationLifetime appLifetime.ApplicationStarted.Register(() => Start()); appLifetime.ApplicationStopping.Register(() => CloseConnections()); _nextHeartbeat = new TimerAwaitable(_heartbeatTickRate, _heartbeatTickRate); + + if (AppContext.TryGetSwitch("Microsoft.AspNetCore.Http.Connections.DoNotUseSendTimeout", out var timeoutDisabled)) + { + _useSendTimeout = !timeoutDisabled; + } } public void Start() @@ -156,9 +162,10 @@ public async Task ScanAsync() connection.StateLock.Release(); } + var utcNow = DateTimeOffset.UtcNow; // Once the decision has been made to dispose we don't check the status again // But don't clean up connections while the debugger is attached. - if (!Debugger.IsAttached && status == HttpConnectionStatus.Inactive && (DateTimeOffset.UtcNow - lastSeenUtc).TotalSeconds > 5) + if (!Debugger.IsAttached && status == HttpConnectionStatus.Inactive && (utcNow - lastSeenUtc).TotalSeconds > 5) { Log.ConnectionTimedOut(_logger, connection.ConnectionId); HttpConnectionsEventSource.Log.ConnectionTimedOut(connection.ConnectionId); @@ -170,6 +177,11 @@ public async Task ScanAsync() } else { + if (!Debugger.IsAttached && _useSendTimeout) + { + connection.TryCancelSend(utcNow.Ticks); + } + // Tick the heartbeat, if the connection is still active connection.TickHeartbeat(); } diff --git a/src/SignalR/common/Http.Connections/src/Internal/TaskExtensions.cs b/src/SignalR/common/Http.Connections/src/Internal/TaskExtensions.cs new file mode 100644 index 000000000000..a901379b75c7 --- /dev/null +++ b/src/SignalR/common/Http.Connections/src/Internal/TaskExtensions.cs @@ -0,0 +1,27 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; + +namespace System.Threading.Tasks +{ + internal static class TaskExtensions + { + public static async Task NoThrow(this Task task) + { + await new NoThrowAwaiter(task); + } + } + + internal readonly struct NoThrowAwaiter : ICriticalNotifyCompletion + { + private readonly Task _task; + public NoThrowAwaiter(Task task) { _task = task; } + public NoThrowAwaiter GetAwaiter() => this; + public bool IsCompleted => _task.IsCompleted; + // Observe exception + public void GetResult() { _ = _task.Exception; } + public void OnCompleted(Action continuation) => _task.GetAwaiter().OnCompleted(continuation); + public void UnsafeOnCompleted(Action continuation) => OnCompleted(continuation); + } +} diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingTransport.cs index cf5638d74d94..15eb3026050b 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingTransport.cs @@ -16,6 +16,7 @@ public class LongPollingTransport : IHttpTransport private readonly PipeReader _application; private readonly ILogger _logger; private readonly CancellationToken _timeoutToken; + private readonly HttpConnectionContext _connection; public LongPollingTransport(CancellationToken timeoutToken, PipeReader application, ILoggerFactory loggerFactory) { @@ -24,6 +25,12 @@ public LongPollingTransport(CancellationToken timeoutToken, PipeReader applicati _logger = loggerFactory.CreateLogger(); } + internal LongPollingTransport(CancellationToken timeoutToken, PipeReader application, ILoggerFactory loggerFactory, HttpConnectionContext connection) + : this(timeoutToken, application, loggerFactory) + { + _connection = connection; + } + public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) { try @@ -31,37 +38,40 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationToken tok var result = await _application.ReadAsync(token); var buffer = result.Buffer; - if (buffer.IsEmpty && result.IsCompleted) + try { - Log.LongPolling204(_logger); - context.Response.ContentType = "text/plain"; - context.Response.StatusCode = StatusCodes.Status204NoContent; - return; - } + if (buffer.IsEmpty && (result.IsCompleted || result.IsCanceled)) + { + Log.LongPolling204(_logger); + context.Response.ContentType = "text/plain"; + context.Response.StatusCode = StatusCodes.Status204NoContent; + return; + } - // We're intentionally not checking cancellation here because we need to drain messages we've got so far, - // but it's too late to emit the 204 required by being canceled. + // We're intentionally not checking cancellation here because we need to drain messages we've got so far, + // but it's too late to emit the 204 required by being canceled. - Log.LongPollingWritingMessage(_logger, buffer.Length); + Log.LongPollingWritingMessage(_logger, buffer.Length); - context.Response.ContentLength = buffer.Length; - context.Response.ContentType = "application/octet-stream"; + context.Response.ContentLength = buffer.Length; + context.Response.ContentType = "application/octet-stream"; - try - { - await context.Response.Body.WriteAsync(buffer); + _connection?.StartSendCancellation(); + await context.Response.Body.WriteAsync(buffer, _connection?.SendingToken ?? default); } finally { + _connection?.StopSendCancellation(); _application.AdvanceTo(buffer.End); } } catch (OperationCanceledException) { - // 3 cases: + // 4 cases: // 1 - Request aborted, the client disconnected (no response) // 2 - The poll timeout is hit (204) - // 3 - A new request comes in and cancels this request (204) + // 3 - SendingToken was canceled, abort the connection + // 4 - A new request comes in and cancels this request (204) // Case 1 if (context.RequestAborted.IsCancellationRequested) @@ -79,9 +89,16 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationToken tok context.Response.ContentType = "text/plain"; context.Response.StatusCode = StatusCodes.Status200OK; } - else + else if (_connection?.SendingToken.IsCancellationRequested == true) { // Case 3 + context.Response.ContentType = "text/plain"; + context.Response.StatusCode = StatusCodes.Status204NoContent; + throw; + } + else + { + // Case 4 Log.LongPolling204(_logger); context.Response.ContentType = "text/plain"; context.Response.StatusCode = StatusCodes.Status204NoContent; diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsMessageFormatter.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsMessageFormatter.cs index 0122344e0eb8..e0d089832471 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsMessageFormatter.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsMessageFormatter.cs @@ -4,6 +4,7 @@ using System; using System.Buffers; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace Microsoft.AspNetCore.Http.Connections.Internal @@ -15,19 +16,24 @@ public static class ServerSentEventsMessageFormatter private const byte LineFeed = (byte)'\n'; - public static async Task WriteMessageAsync(ReadOnlySequence payload, Stream output) + public static Task WriteMessageAsync(ReadOnlySequence payload, Stream output) + { + return WriteMessageAsync(payload, output, default); + } + + internal static async Task WriteMessageAsync(ReadOnlySequence payload, Stream output, CancellationToken token) { // Payload does not contain a line feed so write it directly to output if (payload.PositionOf(LineFeed) == null) { if (payload.Length > 0) { - await output.WriteAsync(DataPrefix, 0, DataPrefix.Length); - await output.WriteAsync(payload); - await output.WriteAsync(Newline, 0, Newline.Length); + await output.WriteAsync(DataPrefix, 0, DataPrefix.Length, token); + await output.WriteAsync(payload, token); + await output.WriteAsync(Newline, 0, Newline.Length, token); } - await output.WriteAsync(Newline, 0, Newline.Length); + await output.WriteAsync(Newline, 0, Newline.Length, token); return; } @@ -37,7 +43,7 @@ public static async Task WriteMessageAsync(ReadOnlySequence payload, Strea await WriteMessageToMemory(ms, payload); ms.Position = 0; - await ms.CopyToAsync(output); + await ms.CopyToAsync(output, bufferSize: 81920, token); } /// diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsTransport.cs index 1c0dd85719c0..8ffacd9b2a25 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsTransport.cs @@ -15,6 +15,7 @@ public class ServerSentEventsTransport : IHttpTransport private readonly PipeReader _application; private readonly string _connectionId; private readonly ILogger _logger; + private readonly HttpConnectionContext _connection; public ServerSentEventsTransport(PipeReader application, string connectionId, ILoggerFactory loggerFactory) { @@ -23,6 +24,12 @@ public ServerSentEventsTransport(PipeReader application, string connectionId, IL _logger = loggerFactory.CreateLogger(); } + internal ServerSentEventsTransport(PipeReader application, string connectionId, HttpConnectionContext connection, ILoggerFactory loggerFactory) + : this(application, connectionId, loggerFactory) + { + _connection = connection; + } + public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) { context.Response.ContentType = "text/event-stream"; @@ -52,15 +59,17 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationToken tok { Log.SSEWritingMessage(_logger, buffer.Length); - await ServerSentEventsMessageFormatter.WriteMessageAsync(buffer, context.Response.Body); + _connection?.StartSendCancellation(); + await ServerSentEventsMessageFormatter.WriteMessageAsync(buffer, context.Response.Body, _connection?.SendingToken ?? default); } - else if (result.IsCompleted) + else if (result.IsCompleted || result.IsCanceled) { break; } } finally { + _connection?.StopSendCancellation(); _application.AdvanceTo(buffer.End); } } diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsTransport.cs index e77dbfe102e4..ee183eeefdbd 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsTransport.cs @@ -241,7 +241,8 @@ private async Task StartSending(WebSocket socket) if (WebSocketCanSend(socket)) { - await socket.SendAsync(buffer, webSocketMessageType); + _connection.StartSendCancellation(); + await socket.SendAsync(buffer, webSocketMessageType, _connection.SendingToken); } else { @@ -256,6 +257,10 @@ private async Task StartSending(WebSocket socket) } break; } + finally + { + _connection.StopSendCancellation(); + } } else if (result.IsCompleted) { @@ -283,7 +288,6 @@ private async Task StartSending(WebSocket socket) _application.Input.Complete(); } - } private static bool WebSocketCanSend(WebSocket ws) diff --git a/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj b/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj index b8ded535e0fd..a7a8321eece4 100644 --- a/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj +++ b/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj @@ -1,4 +1,4 @@ - + Components for providing real-time bi-directional communication across the Web. @@ -28,6 +28,7 @@ + diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index 39d19723bda2..cb80af23c89b 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -912,6 +912,215 @@ public async Task LongPollingTimeoutSets200StatusCode() } } + private class BlockingStream : Stream + { + private readonly SyncPoint _sync; + private bool _isSSE; + + public BlockingStream(SyncPoint sync, bool isSSE = false) + { + _sync = sync; + _isSSE = isSSE; + } + + public override bool CanRead => throw new NotImplementedException(); + + public override bool CanSeek => throw new NotImplementedException(); + + public override bool CanWrite => throw new NotImplementedException(); + + public override long Length => throw new NotImplementedException(); + + public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (_isSSE) + { + // SSE does an initial write of :\r\n that we want to ignore in testing + _isSSE = false; + return; + } + await _sync.WaitToContinue(); + cancellationToken.ThrowIfCancellationRequested(); + } + +#if NETCOREAPP2_1 + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (_isSSE) + { + // SSE does an initial write of :\r\n that we want to ignore in testing + _isSSE = false; + return; + } + await _sync.WaitToContinue(); + cancellationToken.ThrowIfCancellationRequested(); + } +#endif + } + + [Fact] + public async Task LongPollingConnectionClosesWhenSendTimeoutReached() + { + bool ExpectedErrors(WriteContext writeContext) + { + return (writeContext.LoggerName == typeof(Internal.Transports.LongPollingTransport).FullName && + writeContext.EventId.Name == "LongPollingTerminated") || + (writeContext.LoggerName == typeof(HttpConnectionManager).FullName && writeContext.EventId.Name == "FailedDispose"); + } + + using (StartVerifiableLog(out var loggerFactory, LogLevel.Debug, expectedErrorsFilter: ExpectedErrors)) + { + var manager = CreateConnectionManager(loggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + + var context = MakeRequest("/foo", connection); + + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + // First poll completes immediately + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); + + var sync = new SyncPoint(); + context.Response.Body = new BlockingStream(sync); + var dispatcherTask = dispatcher.ExecuteAsync(context, options, app); + + await connection.Transport.Output.WriteAsync(new byte[] { 1 }).OrTimeout(); + + await sync.WaitForSyncPoint().OrTimeout(); + // Cancel write to response body + connection.TryCancelSend(long.MaxValue); + sync.Continue(); + + await dispatcherTask.OrTimeout(); + + // Connection should be removed on canceled write + Assert.False(manager.TryGetConnection(connection.ConnectionId, out var _)); + } + } + + [Fact] + public async Task SSEConnectionClosesWhenSendTimeoutReached() + { + using (StartVerifiableLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.ServerSentEvents; + + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + + var context = MakeRequest("/foo", connection); + SetTransport(context, connection.TransportType); + + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + + var sync = new SyncPoint(); + context.Response.Body = new BlockingStream(sync, isSSE: true); + + var options = new HttpConnectionDispatcherOptions(); + var dispatcherTask = dispatcher.ExecuteAsync(context, options, app); + + await connection.Transport.Output.WriteAsync(new byte[] { 1 }).OrTimeout(); + + await sync.WaitForSyncPoint().OrTimeout(); + // Cancel write to response body + connection.TryCancelSend(long.MaxValue); + sync.Continue(); + + await dispatcherTask.OrTimeout(); + + // Connection should be removed on canceled write + Assert.False(manager.TryGetConnection(connection.ConnectionId, out var _)); + } + } + + [Fact] + public async Task WebSocketConnectionClosesWhenSendTimeoutReached() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(Internal.Transports.WebSocketsTransport).FullName && + writeContext.EventId.Name == "ErrorWritingFrame"; + } + + using (StartVerifiableLog(out var loggerFactory, LogLevel.Debug, expectedErrorsFilter: ExpectedErrors)) + { + var manager = CreateConnectionManager(loggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.WebSockets; + + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + + var sync = new SyncPoint(); + var context = MakeRequest("/foo", connection); + SetTransport(context, connection.TransportType, sync); + + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + + var options = new HttpConnectionDispatcherOptions(); + options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(0); + var dispatcherTask = dispatcher.ExecuteAsync(context, options, app); + + await connection.Transport.Output.WriteAsync(new byte[] { 1 }).OrTimeout(); + + await sync.WaitForSyncPoint().OrTimeout(); + // Cancel write to response body + connection.TryCancelSend(long.MaxValue); + sync.Continue(); + + await dispatcherTask.OrTimeout(); + + // Connection should be removed on canceled write + Assert.False(manager.TryGetConnection(connection.ConnectionId, out var _)); + } + } + [Fact] public async Task WebSocketTransportTimesOutWhenCloseFrameNotReceived() { @@ -1719,6 +1928,8 @@ public async Task DeleteEndpointGracefullyTerminatesLongPolling() Assert.Equal(StatusCodes.Status202Accepted, deleteContext.Response.StatusCode); Assert.Equal("text/plain", deleteContext.Response.ContentType); + await connection.DisposeAndRemoveTask.OrTimeout(); + // Verify the connection was removed from the manager Assert.False(manager.TryGetConnection(connection.ConnectionId, out _)); } @@ -1772,6 +1983,110 @@ public async Task DeleteEndpointGracefullyTerminatesLongPollingEvenWhenBetweenPo } } + [Fact] + public async Task DeleteEndpointTerminatesLongPollingWithHangingApplication() + { + using (StartVerifiableLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var pipeOptions = new PipeOptions(pauseWriterThreshold: 2, resumeWriterThreshold: 1); + var connection = manager.CreateConnection(pipeOptions, pipeOptions); + connection.TransportType = HttpTransportType.LongPolling; + + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + + var context = MakeRequest("/foo", connection); + + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + + var pollTask = dispatcher.ExecuteAsync(context, options, app); + Assert.True(pollTask.IsCompleted); + + // Now send the second poll + pollTask = dispatcher.ExecuteAsync(context, options, app); + + // Issue the delete request and make sure the poll completes + var deleteContext = new DefaultHttpContext(); + deleteContext.Request.Path = "/foo"; + deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); + deleteContext.Request.Method = "DELETE"; + + Assert.False(pollTask.IsCompleted); + + await dispatcher.ExecuteAsync(deleteContext, options, app).OrTimeout(); + + await pollTask.OrTimeout(); + + // Verify that transport shuts down + await connection.TransportTask.OrTimeout(); + + // Verify the response from the DELETE request + Assert.Equal(StatusCodes.Status202Accepted, deleteContext.Response.StatusCode); + Assert.Equal("text/plain", deleteContext.Response.ContentType); + Assert.Equal(HttpConnectionStatus.Disposed, connection.Status); + + // Verify the connection not removed because application is hanging + Assert.True(manager.TryGetConnection(connection.ConnectionId, out _)); + } + } + + [Fact] + public async Task PollCanReceiveFinalMessageAfterAppCompletes() + { + using (StartVerifiableLog(out var loggerFactory, LogLevel.Debug)) + { + var transportType = HttpTransportType.LongPolling; + var manager = CreateConnectionManager(loggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = transportType; + + var waitForMessageTcs1 = new TaskCompletionSource(); + var messageTcs1 = new TaskCompletionSource(); + var waitForMessageTcs2 = new TaskCompletionSource(); + var messageTcs2 = new TaskCompletionSource(); + ConnectionDelegate connectionDelegate = async c => + { + await waitForMessageTcs1.Task.OrTimeout(); + await c.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Message1")).OrTimeout(); + messageTcs1.TrySetResult(null); + await waitForMessageTcs2.Task.OrTimeout(); + await c.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Message2")).OrTimeout(); + messageTcs2.TrySetResult(null); + }; + { + var options = new HttpConnectionDispatcherOptions(); + var context = MakeRequest("/foo", connection); + await dispatcher.ExecuteAsync(context, options, connectionDelegate).OrTimeout(); + + // second poll should have data + waitForMessageTcs1.SetResult(null); + await messageTcs1.Task.OrTimeout(); + + var ms = new MemoryStream(); + context.Response.Body = ms; + // Now send the second poll + await dispatcher.ExecuteAsync(context, options, connectionDelegate).OrTimeout(); + Assert.Equal("Message1", Encoding.UTF8.GetString(ms.ToArray())); + + waitForMessageTcs2.SetResult(null); + await messageTcs2.Task.OrTimeout(); + + context = MakeRequest("/foo", connection); + ms.Seek(0, SeekOrigin.Begin); + context.Response.Body = ms; + // This is the third poll which gets the final message after the app is complete + await dispatcher.ExecuteAsync(context, options, connectionDelegate).OrTimeout(); + Assert.Equal("Message2", Encoding.UTF8.GetString(ms.ToArray())); + } + } + } + [Fact] public async Task NegotiateDoesNotReturnWebSocketsWhenNotAvailable() { @@ -2080,12 +2395,12 @@ private static DefaultHttpContext MakeRequest(string path, ConnectionContext con return context; } - private static void SetTransport(HttpContext context, HttpTransportType transportType) + private static void SetTransport(HttpContext context, HttpTransportType transportType, SyncPoint sync = null) { switch (transportType) { case HttpTransportType.WebSockets: - context.Features.Set(new TestWebSocketConnectionFeature()); + context.Features.Set(new TestWebSocketConnectionFeature(sync)); break; case HttpTransportType.ServerSentEvents: context.Request.Headers["Accept"] = "text/event-stream"; diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs index 820d28749d03..720ba58b5ccc 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs @@ -178,7 +178,7 @@ public async Task CloseConnectionsEndsAllPendingConnections() var result = await connection.Application.Input.ReadAsync(); try { - Assert.True(result.IsCompleted); + Assert.True(result.IsCanceled); } finally { diff --git a/src/SignalR/common/Http.Connections/test/TestWebSocketConnectionFeature.cs b/src/SignalR/common/Http.Connections/test/TestWebSocketConnectionFeature.cs index c0711f4acec6..1ca861cbd26c 100644 --- a/src/SignalR/common/Http.Connections/test/TestWebSocketConnectionFeature.cs +++ b/src/SignalR/common/Http.Connections/test/TestWebSocketConnectionFeature.cs @@ -5,11 +5,21 @@ using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.SignalR.Tests; namespace Microsoft.AspNetCore.Http.Connections.Tests { internal class TestWebSocketConnectionFeature : IHttpWebSocketFeature, IDisposable { + public TestWebSocketConnectionFeature() + { } + + public TestWebSocketConnectionFeature(SyncPoint sync) + { + _sync = sync; + } + + private readonly SyncPoint _sync; private readonly TaskCompletionSource _accepted = new TaskCompletionSource(); public bool IsWebSocketRequest => true; @@ -27,8 +37,8 @@ public Task AcceptAsync(WebSocketAcceptContext context) var clientToServer = Channel.CreateUnbounded(); var serverToClient = Channel.CreateUnbounded(); - var clientSocket = new WebSocketChannel(serverToClient.Reader, clientToServer.Writer); - var serverSocket = new WebSocketChannel(clientToServer.Reader, serverToClient.Writer); + var clientSocket = new WebSocketChannel(serverToClient.Reader, clientToServer.Writer, _sync); + var serverSocket = new WebSocketChannel(clientToServer.Reader, serverToClient.Writer, _sync); Client = clientSocket; SubProtocol = context.SubProtocol; @@ -45,16 +55,18 @@ public class WebSocketChannel : WebSocket { private readonly ChannelReader _input; private readonly ChannelWriter _output; + private readonly SyncPoint _sync; private WebSocketCloseStatus? _closeStatus; private string _closeStatusDescription; private WebSocketState _state; private WebSocketMessage _internalBuffer = new WebSocketMessage(); - public WebSocketChannel(ChannelReader input, ChannelWriter output) + public WebSocketChannel(ChannelReader input, ChannelWriter output, SyncPoint sync = null) { _input = input; _output = output; + _sync = sync; } public override WebSocketCloseStatus? CloseStatus => _closeStatus; @@ -173,11 +185,17 @@ public override async Task ReceiveAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + public override async Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { + if (_sync != null) + { + await _sync.WaitToContinue(); + } + cancellationToken.ThrowIfCancellationRequested(); + var copy = new byte[buffer.Count]; Buffer.BlockCopy(buffer.Array, buffer.Offset, copy, 0, buffer.Count); - return SendMessageAsync(new WebSocketMessage + await SendMessageAsync(new WebSocketMessage { Buffer = copy, MessageType = messageType, diff --git a/src/SignalR/common/Shared/PipeWriterStream.cs b/src/SignalR/common/Shared/PipeWriterStream.cs index eb5b6d5addef..245731bfd925 100644 --- a/src/SignalR/common/Shared/PipeWriterStream.cs +++ b/src/SignalR/common/Shared/PipeWriterStream.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -76,7 +76,15 @@ private ValueTask WriteCoreAsync(ReadOnlyMemory source, CancellationToken _length += source.Length; var task = _pipeWriter.WriteAsync(source); - if (!task.IsCompletedSuccessfully) + if (task.IsCompletedSuccessfully) + { + // Cancellation can be triggered by PipeWriter.CancelPendingFlush + if (task.Result.IsCanceled) + { + throw new OperationCanceledException(); + } + } + else if (!task.IsCompletedSuccessfully) { return WriteSlowAsync(task); } diff --git a/src/SignalR/common/testassets/Tests.Utils/TestClient.cs b/src/SignalR/common/testassets/Tests.Utils/TestClient.cs index 534e83cf2b23..39fcd06a3723 100644 --- a/src/SignalR/common/testassets/Tests.Utils/TestClient.cs +++ b/src/SignalR/common/testassets/Tests.Utils/TestClient.cs @@ -32,9 +32,10 @@ public class TestClient : ITransferFormatFeature, IConnectionHeartbeatFeature, I public TransferFormat ActiveFormat { get; set; } - public TestClient(IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, string userIdentifier = null) + public TestClient(IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, string userIdentifier = null, long pauseWriterThreshold = 32768) { - var options = new PipeOptions(readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var options = new PipeOptions(readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false, + pauseWriterThreshold: pauseWriterThreshold, resumeWriterThreshold: pauseWriterThreshold / 2); var pair = DuplexPipe.CreateConnectionPair(options, options); Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Transport, pair.Application); @@ -65,16 +66,7 @@ public async Task ConnectAsync( { if (sendHandshakeRequestMessage) { - var memoryBufferWriter = MemoryBufferWriter.Get(); - try - { - HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter); - await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray()); - } - finally - { - MemoryBufferWriter.Return(memoryBufferWriter); - } + await Connection.Application.Output.WriteAsync(GetHandshakeRequestMessage()); } var connection = handler.OnConnectedAsync(Connection); @@ -290,6 +282,21 @@ public void TickHeartbeat() } } } + + public byte[] GetHandshakeRequestMessage() + { + var memoryBufferWriter = MemoryBufferWriter.Get(); + try + { + HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter); + return memoryBufferWriter.ToArray(); + } + finally + { + MemoryBufferWriter.Return(memoryBufferWriter); + } + } + private class DefaultInvocationBinder : IInvocationBinder { public IReadOnlyList GetParameterTypes(string methodName) diff --git a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs index 3c835ab933e2..d9183ae7cc2b 100644 --- a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs @@ -82,10 +82,10 @@ public override Task RemoveFromGroupAsync(string connectionId, string groupName, /// public override Task SendAllAsync(string methodName, object[] args, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, null); + return SendToAllConnections(methodName, args, include: null, cancellationToken); } - private Task SendToAllConnections(string methodName, object[] args, Func include) + private Task SendToAllConnections(string methodName, object[] args, Func include, CancellationToken cancellationToken) { List tasks = null; SerializedHubMessage message = null; @@ -103,7 +103,7 @@ private Task SendToAllConnections(string methodName, object[] args, Func connections, Func include, ref List tasks, ref SerializedHubMessage message) + private void SendToGroupConnections(string methodName, object[] args, ConcurrentDictionary connections, Func include, + ref List tasks, ref SerializedHubMessage message, CancellationToken cancellationToken) { // foreach over ConcurrentDictionary avoids allocating an enumerator foreach (var connection in connections) @@ -142,7 +143,7 @@ private void SendToGroupConnections(string methodName, object[] args, Concurrent message = CreateSerializedInvocationMessage(methodName, args); } - var task = connection.Value.WriteAsync(message); + var task = connection.Value.WriteAsync(message, cancellationToken); if (!task.IsCompletedSuccessfully) { @@ -175,7 +176,7 @@ public override Task SendConnectionAsync(string connectionId, string methodName, // Write message directly to connection without caching it in memory var message = CreateInvocationMessage(methodName, args); - return connection.WriteAsync(message).AsTask(); + return connection.WriteAsync(message, cancellationToken).AsTask(); } /// @@ -193,7 +194,7 @@ public override Task SendGroupAsync(string groupName, string methodName, object[ // group might be modified inbetween checking and sending List tasks = null; SerializedHubMessage message = null; - SendToGroupConnections(methodName, args, group, null, ref tasks, ref message); + SendToGroupConnections(methodName, args, group, include: null, ref tasks, ref message, cancellationToken); if (tasks != null) { @@ -221,7 +222,7 @@ public override Task SendGroupsAsync(IReadOnlyList groupNames, string me var group = _groups[groupName]; if (group != null) { - SendToGroupConnections(methodName, args, group, null, ref tasks, ref message); + SendToGroupConnections(methodName, args, group, include: null, ref tasks, ref message, cancellationToken); } } @@ -247,7 +248,7 @@ public override Task SendGroupExceptAsync(string groupName, string methodName, o List tasks = null; SerializedHubMessage message = null; - SendToGroupConnections(methodName, args, group, connection => !excludedConnectionIds.Contains(connection.ConnectionId), ref tasks, ref message); + SendToGroupConnections(methodName, args, group, connection => !excludedConnectionIds.Contains(connection.ConnectionId), ref tasks, ref message, cancellationToken); if (tasks != null) { @@ -271,7 +272,7 @@ private HubMessage CreateInvocationMessage(string methodName, object[] args) /// public override Task SendUserAsync(string userId, string methodName, object[] args, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, connection => string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal)); + return SendToAllConnections(methodName, args, connection => string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal), cancellationToken); } /// @@ -292,19 +293,19 @@ public override Task OnDisconnectedAsync(HubConnectionContext connection) /// public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, connection => !excludedConnectionIds.Contains(connection.ConnectionId)); + return SendToAllConnections(methodName, args, connection => !excludedConnectionIds.Contains(connection.ConnectionId), cancellationToken); } /// public override Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object[] args, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, connection => connectionIds.Contains(connection.ConnectionId)); + return SendToAllConnections(methodName, args, connection => connectionIds.Contains(connection.ConnectionId), cancellationToken); } /// public override Task SendUsersAsync(IReadOnlyList userIds, string methodName, object[] args, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, connection => userIds.Contains(connection.UserIdentifier)); + return SendToAllConnections(methodName, args, connection => userIds.Contains(connection.UserIdentifier), cancellationToken); } } } diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 5a1049e78042..8ae666717548 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -33,6 +33,7 @@ public class HubConnectionContext private long _lastSendTimestamp = Stopwatch.GetTimestamp(); private ReadOnlyMemory _cachedPingMessage; + private volatile bool _connectionAborted; /// /// Initializes a new instance of the class. @@ -96,11 +97,17 @@ public virtual ValueTask WriteAsync(HubMessage message, CancellationToken cancel // Try to grab the lock synchronously, if we fail, go to the slower path if (!_writeLock.Wait(0)) { - return new ValueTask(WriteSlowAsync(message)); + return new ValueTask(WriteSlowAsync(message, cancellationToken)); + } + + if (_connectionAborted) + { + _writeLock.Release(); + return default; } // This method should never throw synchronously - var task = WriteCore(message); + var task = WriteCore(message, cancellationToken); // The write didn't complete synchronously so await completion if (!task.IsCompletedSuccessfully) @@ -126,11 +133,17 @@ public virtual ValueTask WriteAsync(SerializedHubMessage message, CancellationTo // Try to grab the lock synchronously, if we fail, go to the slower path if (!_writeLock.Wait(0)) { - return new ValueTask(WriteSlowAsync(message)); + return new ValueTask(WriteSlowAsync(message, cancellationToken)); + } + + if (_connectionAborted) + { + _writeLock.Release(); + return default; } // This method should never throw synchronously - var task = WriteCore(message); + var task = WriteCore(message, cancellationToken); // The write didn't complete synchronously so await completion if (!task.IsCompletedSuccessfully) @@ -144,7 +157,7 @@ public virtual ValueTask WriteAsync(SerializedHubMessage message, CancellationTo return default; } - private ValueTask WriteCore(HubMessage message) + private ValueTask WriteCore(HubMessage message, CancellationToken cancellationToken) { try { @@ -152,29 +165,33 @@ private ValueTask WriteCore(HubMessage message) // write it without caching. Protocol.WriteMessage(message, _connectionContext.Transport.Output); - return _connectionContext.Transport.Output.FlushAsync(); + return _connectionContext.Transport.Output.FlushAsync(cancellationToken); } catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + Abort(); + return new ValueTask(new FlushResult(isCanceled: false, isCompleted: true)); } } - private ValueTask WriteCore(SerializedHubMessage message) + private ValueTask WriteCore(SerializedHubMessage message, CancellationToken cancellationToken) { try { // Grab a preserialized buffer for this protocol. var buffer = message.GetSerializedMessage(Protocol); - return _connectionContext.Transport.Output.WriteAsync(buffer); + return _connectionContext.Transport.Output.WriteAsync(buffer, cancellationToken); } catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + Abort(); + return new ValueTask(new FlushResult(isCanceled: false, isCompleted: true)); } } @@ -188,6 +205,8 @@ private async Task CompleteWriteAsync(ValueTask task) catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + + Abort(); } finally { @@ -196,18 +215,25 @@ private async Task CompleteWriteAsync(ValueTask task) } } - private async Task WriteSlowAsync(HubMessage message) + private async Task WriteSlowAsync(HubMessage message, CancellationToken cancellationToken) { - await _writeLock.WaitAsync(); + // Failed to get the lock immediately when entering WriteAsync so await until it is available + await _writeLock.WaitAsync(cancellationToken); + try { - // Failed to get the lock immediately when entering WriteAsync so await until it is available + if (_connectionAborted) + { + return; + } - await WriteCore(message); + await WriteCore(message, cancellationToken); } catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + + Abort(); } finally { @@ -215,18 +241,25 @@ private async Task WriteSlowAsync(HubMessage message) } } - private async Task WriteSlowAsync(SerializedHubMessage message) + private async Task WriteSlowAsync(SerializedHubMessage message, CancellationToken cancellationToken) { + // Failed to get the lock immediately when entering WriteAsync so await until it is available + await _writeLock.WaitAsync(cancellationToken); + try { - // Failed to get the lock immediately when entering WriteAsync so await until it is available - await _writeLock.WaitAsync(); + if (_connectionAborted) + { + return; + } - await WriteCore(message); + await WriteCore(message, cancellationToken); } catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + + Abort(); } finally { @@ -243,6 +276,7 @@ private ValueTask TryWritePingAsync() return default; } + // TODO: cancel? return new ValueTask(TryWritePingSlowAsync()); } @@ -250,6 +284,11 @@ private async Task TryWritePingSlowAsync() { try { + if (_connectionAborted) + { + return; + } + await _connectionContext.Transport.Output.WriteAsync(_cachedPingMessage); Log.SentPing(_logger); @@ -257,6 +296,8 @@ private async Task TryWritePingSlowAsync() catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + + Abort(); } finally { @@ -293,6 +334,12 @@ private async Task WriteHandshakeResponseAsync(HandshakeResponseMessage message) /// public virtual void Abort() { + _connectionAborted = true; + + // Cancel any current writes or writes that are about to happen and have already gone past the _connectionAborted bool + // We have to do this outside of the lock otherwise it could hang if the write is observing backpressure + _connectionContext.Transport.Output.CancelPendingFlush(); + // If we already triggered the token then noop, this isn't thread safe but it's good enough // to avoid spawning a new task in the most common cases if (_connectionAbortedTokenSource.IsCancellationRequested) @@ -423,9 +470,24 @@ internal void Abort(Exception exception) internal Task AbortAsync() { Abort(); + + // Acquire lock to make sure all writes are completed + if (!_writeLock.Wait(0)) + { + return AbortAsyncSlow(); + } + + _writeLock.Release(); return _abortCompletedTcs.Task; } + private async Task AbortAsyncSlow() + { + await _writeLock.WaitAsync(); + _writeLock.Release(); + await _abortCompletedTcs.Task; + } + private void KeepAliveTick() { var timestamp = Stopwatch.GetTimestamp(); @@ -449,12 +511,10 @@ private void KeepAliveTick() private static void AbortConnection(object state) { var connection = (HubConnectionContext)state; + try { connection._connectionAbortedTokenSource.Cancel(); - - // Communicate the fact that we're finished triggering abort callbacks - connection._abortCompletedTcs.TrySetResult(null); } catch (Exception ex) { @@ -462,6 +522,26 @@ private static void AbortConnection(object state) // we don't end up with an unobserved task connection._abortCompletedTcs.TrySetException(ex); } + finally + { + _ = InnerAbortConnection(connection); + } + } + + private static async Task InnerAbortConnection(HubConnectionContext connection) + { + // We lock to make sure all writes are done before triggering the completion of the pipe + await connection._writeLock.WaitAsync(); + try + { + // Communicate the fact that we're finished triggering abort callbacks + // HubOnDisconnectedAsync is waiting on this to complete the Pipe + connection._abortCompletedTcs.TrySetResult(null); + } + finally + { + connection._writeLock.Release(); + } } private static class Log diff --git a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs index d2694b17b3e9..1b604368741c 100644 --- a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; diff --git a/src/SignalR/server/Core/src/Internal/Proxies.cs b/src/SignalR/server/Core/src/Internal/Proxies.cs index 9a3edd56bdb5..8a2beb26de03 100644 --- a/src/SignalR/server/Core/src/Internal/Proxies.cs +++ b/src/SignalR/server/Core/src/Internal/Proxies.cs @@ -105,7 +105,7 @@ public AllClientProxy(HubLifetimeManager lifetimeManager) public Task SendCoreAsync(string method, object[] args, CancellationToken cancellationToken = default) { - return _lifetimeManager.SendAllAsync(method, args); + return _lifetimeManager.SendAllAsync(method, args, cancellationToken); } } diff --git a/src/SignalR/server/SignalR/test/DefaultHubLifetimeManagerTests.cs b/src/SignalR/server/SignalR/test/DefaultHubLifetimeManagerTests.cs index 6fcce1872dbd..f911ca1382fa 100644 --- a/src/SignalR/server/SignalR/test/DefaultHubLifetimeManagerTests.cs +++ b/src/SignalR/server/SignalR/test/DefaultHubLifetimeManagerTests.cs @@ -1,3 +1,8 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Protocol; @@ -155,6 +160,289 @@ public async Task RemoveGroupOnNonExistentConnectionNoops() await manager.RemoveFromGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout(); } + [Fact] + public async Task SendAllAsyncWillCancelWithToken() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + + var cts = new CancellationTokenSource(); + var sendTask = manager.SendAllAsync("Hello", new object[] { "World" }, cts.Token).OrTimeout(); + + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + + var message = Assert.IsType(client1.TryRead()); + Assert.Equal("Hello", message.Target); + Assert.Single(message.Arguments); + Assert.Equal("World", (string)message.Arguments[0]); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection2.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + + Assert.False(connection1.ConnectionAborted.IsCancellationRequested); + } + } + + [Fact] + public async Task SendAllExceptAsyncWillCancelWithToken() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + + var cts = new CancellationTokenSource(); + var sendTask = manager.SendAllExceptAsync("Hello", new object[] { "World" }, new List { connection1.ConnectionId }, cts.Token).OrTimeout(); + + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection2.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + + Assert.False(connection1.ConnectionAborted.IsCancellationRequested); + Assert.Null(client1.TryRead()); + } + } + + [Fact] + public async Task SendConnectionAsyncWillCancelWithToken() + { + using (var client1 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager.OnConnectedAsync(connection1).OrTimeout(); + + var cts = new CancellationTokenSource(); + var sendTask = manager.SendConnectionAsync(connection1.ConnectionId, "Hello", new object[] { "World" }, cts.Token).OrTimeout(); + + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection1.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + } + } + + [Fact] + public async Task SendConnectionsAsyncWillCancelWithToken() + { + using (var client1 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager.OnConnectedAsync(connection1).OrTimeout(); + + var cts = new CancellationTokenSource(); + var sendTask = manager.SendConnectionsAsync(new List { connection1.ConnectionId }, "Hello", new object[] { "World" }, cts.Token).OrTimeout(); + + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection1.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + } + } + + [Fact] + public async Task SendGroupAsyncWillCancelWithToken() + { + using (var client1 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager.OnConnectedAsync(connection1).OrTimeout(); + + await manager.AddToGroupAsync(connection1.ConnectionId, "group").OrTimeout(); + + var cts = new CancellationTokenSource(); + var sendTask = manager.SendGroupAsync("group", "Hello", new object[] { "World" }, cts.Token).OrTimeout(); + + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection1.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + } + } + + [Fact] + public async Task SendGroupExceptAsyncWillCancelWithToken() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + + await manager.AddToGroupAsync(connection1.ConnectionId, "group").OrTimeout(); + await manager.AddToGroupAsync(connection2.ConnectionId, "group").OrTimeout(); + + var cts = new CancellationTokenSource(); + var sendTask = manager.SendGroupExceptAsync("group", "Hello", new object[] { "World" }, new List { connection1.ConnectionId }, cts.Token).OrTimeout(); + + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection2.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + + Assert.False(connection1.ConnectionAborted.IsCancellationRequested); + Assert.Null(client1.TryRead()); + } + } + + [Fact] + public async Task SendGroupsAsyncWillCancelWithToken() + { + using (var client1 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager.OnConnectedAsync(connection1).OrTimeout(); + + await manager.AddToGroupAsync(connection1.ConnectionId, "group").OrTimeout(); + + var cts = new CancellationTokenSource(); + var sendTask = manager.SendGroupsAsync(new List { "group" }, "Hello", new object[] { "World" }, cts.Token).OrTimeout(); + + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection1.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + } + } + + [Fact] + public async Task SendUserAsyncWillCancelWithToken() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); + var connection1 = HubConnectionContextUtils.Create(client1.Connection, userIdentifier: "user"); + var connection2 = HubConnectionContextUtils.Create(client2.Connection, userIdentifier: "user"); + + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + + var cts = new CancellationTokenSource(); + var sendTask = manager.SendUserAsync("user", "Hello", new object[] { "World" }, cts.Token).OrTimeout(); + + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + + var message = Assert.IsType(client1.TryRead()); + Assert.Equal("Hello", message.Target); + Assert.Single(message.Arguments); + Assert.Equal("World", (string)message.Arguments[0]); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection2.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + + Assert.False(connection1.ConnectionAborted.IsCancellationRequested); + } + } + + [Fact] + public async Task SendUsersAsyncWillCancelWithToken() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); + var connection1 = HubConnectionContextUtils.Create(client1.Connection, userIdentifier: "user1"); + var connection2 = HubConnectionContextUtils.Create(client2.Connection, userIdentifier: "user2"); + + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + + var cts = new CancellationTokenSource(); + var sendTask = manager.SendUsersAsync(new List { "user1", "user2" }, "Hello", new object[] { "World" }, cts.Token).OrTimeout(); + + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + + var message = Assert.IsType(client1.TryRead()); + Assert.Equal("Hello", message.Target); + Assert.Single(message.Arguments); + Assert.Equal("World", (string)message.Arguments[0]); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection2.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + + Assert.False(connection1.ConnectionAborted.IsCancellationRequested); + } + } + private class MyHub : Hub { diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 06fa3841eb9c..c3100870953b 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -79,9 +79,11 @@ public async Task AbortFromHubMethodForcesClientDisconnect() { var connectionHandlerTask = await client.ConnectAsync(connectionHandler); - await client.InvokeAsync(nameof(AbortHub.Kill)); + await client.SendInvocationAsync(nameof(AbortHub.Kill)); await connectionHandlerTask.OrTimeout(); + + Assert.Null(client.TryRead()); } }