diff options
Diffstat (limited to 'SocketHttpListener/WebSocket.cs')
| -rw-r--r-- | SocketHttpListener/WebSocket.cs | 294 |
1 files changed, 148 insertions, 146 deletions
diff --git a/SocketHttpListener/WebSocket.cs b/SocketHttpListener/WebSocket.cs index 128bc8b97..0dcb6a64b 100644 --- a/SocketHttpListener/WebSocket.cs +++ b/SocketHttpListener/WebSocket.cs @@ -30,9 +30,9 @@ namespace SocketHttpListener private CookieCollection _cookies; private AutoResetEvent _exitReceiving; private object _forConn; - private object _forEvent; + private readonly SemaphoreSlim _forEvent = new SemaphoreSlim(1, 1); private object _forMessageEventQueue; - private object _forSend; + private readonly SemaphoreSlim _forSend = new SemaphoreSlim(1, 1); private const string _guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; private Queue<MessageEventArgs> _messageEventQueue; private string _protocol; @@ -109,12 +109,15 @@ namespace SocketHttpListener #region Private Methods - private void close(CloseStatusCode code, string reason, bool wait) + private async Task CloseAsync(CloseStatusCode code, string reason, bool wait) { - close(new PayloadData(((ushort)code).Append(reason)), !code.IsReserved(), wait); + await CloseAsync(new PayloadData( + await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)), + !code.IsReserved(), + wait).ConfigureAwait(false); } - private void close(PayloadData payload, bool send, bool wait) + private async Task CloseAsync(PayloadData payload, bool send, bool wait) { lock (_forConn) { @@ -126,11 +129,12 @@ namespace SocketHttpListener _readyState = WebSocketState.CloseSent; } - var e = new CloseEventArgs(payload); - e.WasClean = - closeHandshake( + var e = new CloseEventArgs(payload) + { + WasClean = await CloseHandshakeAsync( send ? WebSocketFrame.CreateCloseFrame(Mask.Unmask, payload).ToByteArray() : null, - wait ? 1000 : 0); + wait ? 1000 : 0).ConfigureAwait(false) + }; _readyState = WebSocketState.Closed; try @@ -143,9 +147,9 @@ namespace SocketHttpListener } } - private bool closeHandshake(byte[] frameAsBytes, int millisecondsTimeout) + private async Task<bool> CloseHandshakeAsync(byte[] frameAsBytes, int millisecondsTimeout) { - var sent = frameAsBytes != null && writeBytes(frameAsBytes); + var sent = frameAsBytes != null && await WriteBytesAsync(frameAsBytes).ConfigureAwait(false); var received = millisecondsTimeout == 0 || (sent && _exitReceiving != null && _exitReceiving.WaitOne(millisecondsTimeout)); @@ -189,11 +193,11 @@ namespace SocketHttpListener _context = null; } - private bool concatenateFragmentsInto(Stream dest) + private async Task<bool> ConcatenateFragmentsIntoAsync(Stream dest) { while (true) { - var frame = WebSocketFrame.Read(_stream, true); + var frame = await WebSocketFrame.ReadAsync(_stream, true).ConfigureAwait(false); if (frame.IsFinal) { /* FINAL */ @@ -221,7 +225,7 @@ namespace SocketHttpListener // CLOSE if (frame.IsClose) - return processCloseFrame(frame); + return await ProcessCloseFrameAsync(frame).ConfigureAwait(false); } else { @@ -236,10 +240,10 @@ namespace SocketHttpListener } // ? - return processUnsupportedFrame( + return await ProcessUnsupportedFrameAsync( frame, CloseStatusCode.IncorrectData, - "An incorrect data has been received while receiving fragmented data."); + "An incorrect data has been received while receiving fragmented data.").ConfigureAwait(false); } return true; @@ -299,44 +303,42 @@ namespace SocketHttpListener _compression = CompressionMethod.None; _cookies = new CookieCollection(); _forConn = new object(); - _forEvent = new object(); - _forSend = new object(); _messageEventQueue = new Queue<MessageEventArgs>(); _forMessageEventQueue = ((ICollection)_messageEventQueue).SyncRoot; _readyState = WebSocketState.Connecting; } - private void open() + private async Task OpenAsync() { try { startReceiving(); - lock (_forEvent) - { - try - { - if (OnOpen != null) - { - OnOpen(this, EventArgs.Empty); - } - } - catch (Exception ex) - { - processException(ex, "An exception has occurred while OnOpen."); - } - } } catch (Exception ex) { - processException(ex, "An exception has occurred while opening."); + await ProcessExceptionAsync(ex, "An exception has occurred while opening.").ConfigureAwait(false); + } + + await _forEvent.WaitAsync().ConfigureAwait(false); + try + { + OnOpen?.Invoke(this, EventArgs.Empty); + } + catch (Exception ex) + { + await ProcessExceptionAsync(ex, "An exception has occurred while OnOpen.").ConfigureAwait(false); + } + finally + { + _forEvent.Release(); } } - private bool processCloseFrame(WebSocketFrame frame) + private async Task<bool> ProcessCloseFrameAsync(WebSocketFrame frame) { var payload = frame.PayloadData; - close(payload, !payload.ContainsReservedCloseStatusCode, false); + await CloseAsync(payload, !payload.ContainsReservedCloseStatusCode, false).ConfigureAwait(false); return false; } @@ -352,7 +354,7 @@ namespace SocketHttpListener return true; } - private void processException(Exception exception, string message) + private async Task ProcessExceptionAsync(Exception exception, string message) { var code = CloseStatusCode.Abnormal; var reason = message; @@ -365,25 +367,31 @@ namespace SocketHttpListener error(message ?? code.GetMessage(), exception); if (_readyState == WebSocketState.Connecting) - Close(HttpStatusCode.BadRequest); + { + await CloseAsync(HttpStatusCode.BadRequest).ConfigureAwait(false); + } else - close(code, reason ?? code.GetMessage(), false); + { + await CloseAsync(code, reason ?? code.GetMessage(), false).ConfigureAwait(false); + } } - private bool processFragmentedFrame(WebSocketFrame frame) + private Task<bool> ProcessFragmentedFrameAsync(WebSocketFrame frame) { return frame.IsContinuation // Not first fragment - ? true - : processFragments(frame); + ? Task.FromResult(true) + : ProcessFragmentsAsync(frame); } - private bool processFragments(WebSocketFrame first) + private async Task<bool> ProcessFragmentsAsync(WebSocketFrame first) { using (var buff = new MemoryStream()) { buff.WriteBytes(first.PayloadData.ApplicationData); - if (!concatenateFragmentsInto(buff)) + if (!await ConcatenateFragmentsIntoAsync(buff).ConfigureAwait(false)) + { return false; + } byte[] data; if (_compression != CompressionMethod.None) @@ -412,36 +420,38 @@ namespace SocketHttpListener return true; } - private bool processUnsupportedFrame(WebSocketFrame frame, CloseStatusCode code, string reason) + private async Task<bool> ProcessUnsupportedFrameAsync(WebSocketFrame frame, CloseStatusCode code, string reason) { - processException(new WebSocketException(code, reason), null); + await ProcessExceptionAsync(new WebSocketException(code, reason), null).ConfigureAwait(false); return false; } - private bool processWebSocketFrame(WebSocketFrame frame) + private Task<bool> ProcessWebSocketFrameAsync(WebSocketFrame frame) { + // TODO: @bond change to if/else chain return frame.IsCompressed && _compression == CompressionMethod.None - ? processUnsupportedFrame( + ? ProcessUnsupportedFrameAsync( frame, CloseStatusCode.IncorrectData, "A compressed data has been received without available decompression method.") : frame.IsFragmented - ? processFragmentedFrame(frame) + ? ProcessFragmentedFrameAsync(frame) : frame.IsData - ? processDataFrame(frame) + ? Task.FromResult(processDataFrame(frame)) : frame.IsPing - ? processPingFrame(frame) + ? Task.FromResult(processPingFrame(frame)) : frame.IsPong - ? processPongFrame(frame) + ? Task.FromResult(processPongFrame(frame)) : frame.IsClose - ? processCloseFrame(frame) - : processUnsupportedFrame(frame, CloseStatusCode.PolicyViolation, null); + ? ProcessCloseFrameAsync(frame) + : ProcessUnsupportedFrameAsync(frame, CloseStatusCode.PolicyViolation, null); } - private bool send(Opcode opcode, Stream stream) + private async Task<bool> SendAsync(Opcode opcode, Stream stream) { - lock (_forSend) + await _forSend.WaitAsync().ConfigureAwait(false); + try { var src = stream; var compressed = false; @@ -454,7 +464,7 @@ namespace SocketHttpListener compressed = true; } - sent = send(opcode, Mask.Unmask, stream, compressed); + sent = await SendAsync(opcode, Mask.Unmask, stream, compressed).ConfigureAwait(false); if (!sent) error("Sending a data has been interrupted."); } @@ -472,16 +482,20 @@ namespace SocketHttpListener return sent; } + finally + { + _forSend.Release(); + } } - private bool send(Opcode opcode, Mask mask, Stream stream, bool compressed) + private async Task<bool> SendAsync(Opcode opcode, Mask mask, Stream stream, bool compressed) { var len = stream.Length; /* Not fragmented */ if (len == 0) - return send(Fin.Final, opcode, mask, new byte[0], compressed); + return await SendAsync(Fin.Final, opcode, mask, new byte[0], compressed).ConfigureAwait(false); var quo = len / FragmentLength; var rem = (int)(len % FragmentLength); @@ -490,26 +504,26 @@ namespace SocketHttpListener if (quo == 0) { buff = new byte[rem]; - return stream.Read(buff, 0, rem) == rem && - send(Fin.Final, opcode, mask, buff, compressed); + return await stream.ReadAsync(buff, 0, rem).ConfigureAwait(false) == rem && + await SendAsync(Fin.Final, opcode, mask, buff, compressed).ConfigureAwait(false); } buff = new byte[FragmentLength]; if (quo == 1 && rem == 0) - return stream.Read(buff, 0, FragmentLength) == FragmentLength && - send(Fin.Final, opcode, mask, buff, compressed); + return await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) == FragmentLength && + await SendAsync(Fin.Final, opcode, mask, buff, compressed).ConfigureAwait(false); /* Send fragmented */ // Begin - if (stream.Read(buff, 0, FragmentLength) != FragmentLength || - !send(Fin.More, opcode, mask, buff, compressed)) + if (await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) != FragmentLength || + !await SendAsync(Fin.More, opcode, mask, buff, compressed).ConfigureAwait(false)) return false; var n = rem == 0 ? quo - 2 : quo - 1; for (long i = 0; i < n; i++) - if (stream.Read(buff, 0, FragmentLength) != FragmentLength || - !send(Fin.More, Opcode.Cont, mask, buff, compressed)) + if (await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) != FragmentLength || + !await SendAsync(Fin.More, Opcode.Cont, mask, buff, compressed).ConfigureAwait(false)) return false; // End @@ -518,98 +532,88 @@ namespace SocketHttpListener else buff = new byte[rem]; - return stream.Read(buff, 0, rem) == rem && - send(Fin.Final, Opcode.Cont, mask, buff, compressed); + return await stream.ReadAsync(buff, 0, rem).ConfigureAwait(false) == rem && + await SendAsync(Fin.Final, Opcode.Cont, mask, buff, compressed).ConfigureAwait(false); } - private bool send(Fin fin, Opcode opcode, Mask mask, byte[] data, bool compressed) + private Task<bool> SendAsync(Fin fin, Opcode opcode, Mask mask, byte[] data, bool compressed) { lock (_forConn) { if (_readyState != WebSocketState.Open) { - return false; + return Task.FromResult(false); } - return writeBytes( + return WriteBytesAsync( WebSocketFrame.CreateWebSocketFrame(fin, opcode, mask, data, compressed).ToByteArray()); } } - private Task sendAsync(Opcode opcode, Stream stream) - { - var completionSource = new TaskCompletionSource<bool>(); - Task.Run(() => - { - try - { - send(opcode, stream); - completionSource.TrySetResult(true); - } - catch (Exception ex) - { - completionSource.TrySetException(ex); - } - }); - return completionSource.Task; - } - // As server - private bool sendHttpResponse(HttpResponse response) - { - return writeBytes(response.ToByteArray()); - } + private Task<bool> SendHttpResponseAsync(HttpResponse response) + => WriteBytesAsync(response.ToByteArray()); private void startReceiving() { if (_messageEventQueue.Count > 0) + { _messageEventQueue.Clear(); + } _exitReceiving = new AutoResetEvent(false); _receivePong = new AutoResetEvent(false); Action receive = null; - receive = () => WebSocketFrame.ReadAsync( - _stream, - true, - frame => - { - if (processWebSocketFrame(frame) && _readyState != WebSocketState.Closed) - { - receive(); - - if (!frame.IsData) - return; - - lock (_forEvent) - { - try - { - var e = dequeueFromMessageEventQueue(); - if (e != null && _readyState == WebSocketState.Open) - OnMessage.Emit(this, e); - } - catch (Exception ex) - { - processException(ex, "An exception has occurred while OnMessage."); - } - } - } - else if (_exitReceiving != null) - { - _exitReceiving.Set(); - } - }, - ex => processException(ex, "An exception has occurred while receiving a message.")); + receive = async () => await WebSocketFrame.ReadAsync( + _stream, + true, + async frame => + { + if (await ProcessWebSocketFrameAsync(frame).ConfigureAwait(false) && _readyState != WebSocketState.Closed) + { + receive(); + + if (!frame.IsData) + { + return; + } + + await _forEvent.WaitAsync().ConfigureAwait(false); + + try + { + var e = dequeueFromMessageEventQueue(); + if (e != null && _readyState == WebSocketState.Open) + { + OnMessage.Emit(this, e); + } + } + catch (Exception ex) + { + await ProcessExceptionAsync(ex, "An exception has occurred while OnMessage.").ConfigureAwait(false); + } + finally + { + _forEvent.Release(); + } + + } + else if (_exitReceiving != null) + { + _exitReceiving.Set(); + } + }, + async ex => await ProcessExceptionAsync(ex, "An exception has occurred while receiving a message.")).ConfigureAwait(false); receive(); } - private bool writeBytes(byte[] data) + private async Task<bool> WriteBytesAsync(byte[] data) { try { - _stream.Write(data, 0, data.Length); + await _stream.WriteAsync(data, 0, data.Length).ConfigureAwait(false); return true; } catch (Exception) @@ -623,10 +627,10 @@ namespace SocketHttpListener #region Internal Methods // As server - internal void Close(HttpResponse response) + internal async Task CloseAsync(HttpResponse response) { _readyState = WebSocketState.CloseSent; - sendHttpResponse(response); + await SendHttpResponseAsync(response).ConfigureAwait(false); closeServerResources(); @@ -634,22 +638,20 @@ namespace SocketHttpListener } // As server - internal void Close(HttpStatusCode code) - { - Close(createHandshakeCloseResponse(code)); - } + internal Task CloseAsync(HttpStatusCode code) + => CloseAsync(createHandshakeCloseResponse(code)); // As server - public void ConnectAsServer() + public async Task ConnectAsServer() { try { _readyState = WebSocketState.Open; - open(); + await OpenAsync().ConfigureAwait(false); } catch (Exception ex) { - processException(ex, "An exception has occurred while connecting."); + await ProcessExceptionAsync(ex, "An exception has occurred while connecting.").ConfigureAwait(false); } } @@ -660,18 +662,18 @@ namespace SocketHttpListener /// <summary> /// Closes the WebSocket connection, and releases all associated resources. /// </summary> - public void Close() + public Task CloseAsync() { var msg = _readyState.CheckIfClosable(); if (msg != null) { error(msg); - return; + return Task.CompletedTask; } var send = _readyState == WebSocketState.Open; - close(new PayloadData(), send, send); + return CloseAsync(new PayloadData(), send, send); } /// <summary> @@ -689,11 +691,11 @@ namespace SocketHttpListener /// <param name="reason"> /// A <see cref="string"/> that represents the reason for the close. /// </param> - public void Close(CloseStatusCode code, string reason) + public async Task CloseAsync(CloseStatusCode code, string reason) { byte[] data = null; var msg = _readyState.CheckIfClosable() ?? - (data = ((ushort)code).Append(reason)).CheckIfValidControlData("reason"); + (data = await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)).CheckIfValidControlData("reason"); if (msg != null) { @@ -703,7 +705,7 @@ namespace SocketHttpListener } var send = _readyState == WebSocketState.Open && !code.IsReserved(); - close(new PayloadData(data), send, send); + await CloseAsync(new PayloadData(data), send, send).ConfigureAwait(false); } /// <summary> @@ -728,7 +730,7 @@ namespace SocketHttpListener throw new Exception(msg); } - return sendAsync(Opcode.Binary, new MemoryStream(data)); + return SendAsync(Opcode.Binary, new MemoryStream(data)); } /// <summary> @@ -753,7 +755,7 @@ namespace SocketHttpListener throw new Exception(msg); } - return sendAsync(Opcode.Text, new MemoryStream(Encoding.UTF8.GetBytes(data))); + return SendAsync(Opcode.Text, new MemoryStream(Encoding.UTF8.GetBytes(data))); } #endregion @@ -768,7 +770,7 @@ namespace SocketHttpListener /// </remarks> void IDisposable.Dispose() { - Close(CloseStatusCode.Away, null); + CloseAsync(CloseStatusCode.Away, null).GetAwaiter().GetResult(); } #endregion |
