aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVasily <JustAMan@users.noreply.github.com>2019-02-20 15:03:42 +0300
committerGitHub <noreply@github.com>2019-02-20 15:03:42 +0300
commit8ef41020d94832fbce4a3ef8aba598d30b7adaa5 (patch)
treedca4bf1b09c31b8f5bfd4e2b5815d5ede859179e
parent60df855b263e691f946973a192621e7998db9cbb (diff)
parentfca226bdfde49f30e6347593a9d8870eec55269f (diff)
Merge pull request #847 from Bond-009/async
Make websockets code async
-rw-r--r--Jellyfin.Server/SocketSharp/SharpWebSocket.cs7
-rw-r--r--Jellyfin.Server/SocketSharp/WebSocketSharpListener.cs26
-rw-r--r--SocketHttpListener/Ext.cs69
-rw-r--r--SocketHttpListener/Net/HttpListener.cs43
-rw-r--r--SocketHttpListener/Net/HttpListenerPrefixCollection.cs79
-rw-r--r--SocketHttpListener/WebSocket.cs294
-rw-r--r--SocketHttpListener/WebSocketFrame.cs47
7 files changed, 290 insertions, 275 deletions
diff --git a/Jellyfin.Server/SocketSharp/SharpWebSocket.cs b/Jellyfin.Server/SocketSharp/SharpWebSocket.cs
index 6eee4cd12..9b0951857 100644
--- a/Jellyfin.Server/SocketSharp/SharpWebSocket.cs
+++ b/Jellyfin.Server/SocketSharp/SharpWebSocket.cs
@@ -44,10 +44,11 @@ namespace Jellyfin.Server.SocketSharp
socket.OnMessage += OnSocketMessage;
socket.OnClose += OnSocketClose;
socket.OnError += OnSocketError;
-
- WebSocket.ConnectAsServer();
}
+ public Task ConnectAsServerAsync()
+ => WebSocket.ConnectAsServer();
+
public Task StartReceive()
{
return _taskCompletionSource.Task;
@@ -133,7 +134,7 @@ namespace Jellyfin.Server.SocketSharp
_cancellationTokenSource.Cancel();
- WebSocket.Close();
+ WebSocket.CloseAsync().GetAwaiter().GetResult();
}
_disposed = true;
diff --git a/Jellyfin.Server/SocketSharp/WebSocketSharpListener.cs b/Jellyfin.Server/SocketSharp/WebSocketSharpListener.cs
index 58c4d38a2..736f9feef 100644
--- a/Jellyfin.Server/SocketSharp/WebSocketSharpListener.cs
+++ b/Jellyfin.Server/SocketSharp/WebSocketSharpListener.cs
@@ -69,7 +69,7 @@ namespace Jellyfin.Server.SocketSharp
{
if (_listener == null)
{
- _listener = new HttpListener(_logger, _cryptoProvider, _socketFactory, _networkManager, _streamHelper, _fileSystem, _environment);
+ _listener = new HttpListener(_logger, _cryptoProvider, _socketFactory, _streamHelper, _fileSystem, _environment);
}
_listener.EnableDualMode = _enableDualMode;
@@ -79,22 +79,14 @@ namespace Jellyfin.Server.SocketSharp
_listener.LoadCert(_certificate);
}
- foreach (var prefix in urlPrefixes)
- {
- _logger.LogInformation("Adding HttpListener prefix " + prefix);
- _listener.Prefixes.Add(prefix);
- }
+ _logger.LogInformation("Adding HttpListener prefixes {Prefixes}", urlPrefixes);
+ _listener.Prefixes.AddRange(urlPrefixes);
- _listener.OnContext = ProcessContext;
+ _listener.OnContext = async c => await InitTask(c, _disposeCancellationToken).ConfigureAwait(false);
_listener.Start();
}
- private void ProcessContext(HttpListenerContext context)
- {
- _ = Task.Run(async () => await InitTask(context, _disposeCancellationToken).ConfigureAwait(false));
- }
-
private static void LogRequest(ILogger logger, HttpListenerRequest request)
{
var url = request.Url.ToString();
@@ -151,10 +143,7 @@ namespace Jellyfin.Server.SocketSharp
Endpoint = endpoint
};
- if (WebSocketConnecting != null)
- {
- WebSocketConnecting(connectingArgs);
- }
+ WebSocketConnecting?.Invoke(connectingArgs);
if (connectingArgs.AllowConnection)
{
@@ -165,6 +154,7 @@ namespace Jellyfin.Server.SocketSharp
if (WebSocketConnected != null)
{
var socket = new SharpWebSocket(webSocketContext.WebSocket, _logger);
+ await socket.ConnectAsServerAsync().ConfigureAwait(false);
WebSocketConnected(new WebSocketConnectEventArgs
{
@@ -174,7 +164,7 @@ namespace Jellyfin.Server.SocketSharp
Endpoint = endpoint
});
- await ReceiveWebSocket(ctx, socket).ConfigureAwait(false);
+ await ReceiveWebSocketAsync(ctx, socket).ConfigureAwait(false);
}
}
else
@@ -192,7 +182,7 @@ namespace Jellyfin.Server.SocketSharp
}
}
- private async Task ReceiveWebSocket(HttpListenerContext ctx, SharpWebSocket socket)
+ private async Task ReceiveWebSocketAsync(HttpListenerContext ctx, SharpWebSocket socket)
{
try
{
diff --git a/SocketHttpListener/Ext.cs b/SocketHttpListener/Ext.cs
index a02b48061..2b3c67071 100644
--- a/SocketHttpListener/Ext.cs
+++ b/SocketHttpListener/Ext.cs
@@ -74,18 +74,20 @@ namespace SocketHttpListener
}
}
- private static byte[] readBytes(this Stream stream, byte[] buffer, int offset, int length)
+ private static async Task<byte[]> ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int length)
{
- var len = stream.Read(buffer, offset, length);
+ var len = await stream.ReadAsync(buffer, offset, length).ConfigureAwait(false);
if (len < 1)
return buffer.SubArray(0, offset);
var tmp = 0;
while (len < length)
{
- tmp = stream.Read(buffer, offset + len, length - len);
+ tmp = await stream.ReadAsync(buffer, offset + len, length - len).ConfigureAwait(false);
if (tmp < 1)
+ {
break;
+ }
len += tmp;
}
@@ -95,10 +97,9 @@ namespace SocketHttpListener
: buffer;
}
- private static bool readBytes(
- this Stream stream, byte[] buffer, int offset, int length, Stream dest)
+ private static async Task<bool> ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int length, Stream dest)
{
- var bytes = stream.readBytes(buffer, offset, length);
+ var bytes = await stream.ReadBytesAsync(buffer, offset, length).ConfigureAwait(false);
var len = bytes.Length;
dest.Write(bytes, 0, len);
@@ -109,16 +110,16 @@ namespace SocketHttpListener
#region Internal Methods
- internal static byte[] Append(this ushort code, string reason)
+ internal static async Task<byte[]> AppendAsync(this ushort code, string reason)
{
using (var buffer = new MemoryStream())
{
var tmp = code.ToByteArrayInternally(ByteOrder.Big);
- buffer.Write(tmp, 0, 2);
+ await buffer.WriteAsync(tmp, 0, 2).ConfigureAwait(false);
if (reason != null && reason.Length > 0)
{
tmp = Encoding.UTF8.GetBytes(reason);
- buffer.Write(tmp, 0, tmp.Length);
+ await buffer.WriteAsync(tmp, 0, tmp.Length).ConfigureAwait(false);
}
return buffer.ToArray();
@@ -331,12 +332,10 @@ namespace SocketHttpListener
: string.Format("\"{0}\"", value.Replace("\"", "\\\""));
}
- internal static byte[] ReadBytes(this Stream stream, int length)
- {
- return stream.readBytes(new byte[length], 0, length);
- }
+ internal static Task<byte[]> ReadBytesAsync(this Stream stream, int length)
+ => stream.ReadBytesAsync(new byte[length], 0, length);
- internal static byte[] ReadBytes(this Stream stream, long length, int bufferLength)
+ internal static async Task<byte[]> ReadBytesAsync(this Stream stream, long length, int bufferLength)
{
using (var result = new MemoryStream())
{
@@ -347,7 +346,7 @@ namespace SocketHttpListener
var end = false;
for (long i = 0; i < count; i++)
{
- if (!stream.readBytes(buffer, 0, bufferLength, result))
+ if (!await stream.ReadBytesAsync(buffer, 0, bufferLength, result).ConfigureAwait(false))
{
end = true;
break;
@@ -355,26 +354,14 @@ namespace SocketHttpListener
}
if (!end && rem > 0)
- stream.readBytes(new byte[rem], 0, rem, result);
+ {
+ await stream.ReadBytesAsync(new byte[rem], 0, rem, result).ConfigureAwait(false);
+ }
return result.ToArray();
}
}
- internal static async Task<byte[]> ReadBytesAsync(this Stream stream, int length)
- {
- var buffer = new byte[length];
-
- var len = await stream.ReadAsync(buffer, 0, length).ConfigureAwait(false);
- var bytes = len < 1
- ? new byte[0]
- : len < length
- ? stream.readBytes(buffer, len, length - len)
- : buffer;
-
- return bytes;
- }
-
internal static string RemovePrefix(this string value, params string[] prefixes)
{
var i = 0;
@@ -493,19 +480,16 @@ namespace SocketHttpListener
return string.Format("{0}; {1}", m, parameters.ToString("; "));
}
- internal static List<TSource> ToList<TSource>(this IEnumerable<TSource> source)
- {
- return new List<TSource>(source);
- }
-
internal static ushort ToUInt16(this byte[] src, ByteOrder srcOrder)
{
- return BitConverter.ToUInt16(src.ToHostOrder(srcOrder), 0);
+ src.ToHostOrder(srcOrder);
+ return BitConverter.ToUInt16(src, 0);
}
internal static ulong ToUInt64(this byte[] src, ByteOrder srcOrder)
{
- return BitConverter.ToUInt64(src.ToHostOrder(srcOrder), 0);
+ src.ToHostOrder(srcOrder);
+ return BitConverter.ToUInt64(src, 0);
}
internal static string TrimEndSlash(this string value)
@@ -852,14 +836,17 @@ namespace SocketHttpListener
/// <exception cref="ArgumentNullException">
/// <paramref name="src"/> is <see langword="null"/>.
/// </exception>
- public static byte[] ToHostOrder(this byte[] src, ByteOrder srcOrder)
+ public static void ToHostOrder(this byte[] src, ByteOrder srcOrder)
{
if (src == null)
+ {
throw new ArgumentNullException(nameof(src));
+ }
- return src.Length > 1 && !srcOrder.IsHostOrder()
- ? src.Reverse()
- : src;
+ if (src.Length > 1 && !srcOrder.IsHostOrder())
+ {
+ Array.Reverse(src);
+ }
}
/// <summary>
diff --git a/SocketHttpListener/Net/HttpListener.cs b/SocketHttpListener/Net/HttpListener.cs
index b80180679..f17036a21 100644
--- a/SocketHttpListener/Net/HttpListener.cs
+++ b/SocketHttpListener/Net/HttpListener.cs
@@ -3,7 +3,6 @@ using System.Collections;
using System.Collections.Generic;
using System.Net;
using System.Security.Cryptography.X509Certificates;
-using MediaBrowser.Common.Net;
using MediaBrowser.Model.Cryptography;
using MediaBrowser.Model.IO;
using MediaBrowser.Model.Net;
@@ -18,47 +17,55 @@ namespace SocketHttpListener.Net
internal ISocketFactory SocketFactory { get; private set; }
internal IFileSystem FileSystem { get; private set; }
internal IStreamHelper StreamHelper { get; private set; }
- internal INetworkManager NetworkManager { get; private set; }
internal IEnvironmentInfo EnvironmentInfo { get; private set; }
public bool EnableDualMode { get; set; }
- AuthenticationSchemes auth_schemes;
- HttpListenerPrefixCollection prefixes;
- AuthenticationSchemeSelector auth_selector;
- string realm;
- bool unsafe_ntlm_auth;
- bool listening;
- bool disposed;
+ private AuthenticationSchemes auth_schemes;
+ private HttpListenerPrefixCollection prefixes;
+ private AuthenticationSchemeSelector auth_selector;
+ private string realm;
+ private bool unsafe_ntlm_auth;
+ private bool listening;
+ private bool disposed;
- Dictionary<HttpListenerContext, HttpListenerContext> registry; // Dictionary<HttpListenerContext,HttpListenerContext>
- Dictionary<HttpConnection, HttpConnection> connections;
+ private Dictionary<HttpListenerContext, HttpListenerContext> registry;
+ private Dictionary<HttpConnection, HttpConnection> connections;
private ILogger _logger;
private X509Certificate _certificate;
public Action<HttpListenerContext> OnContext { get; set; }
- public HttpListener(ILogger logger, ICryptoProvider cryptoProvider, ISocketFactory socketFactory,
- INetworkManager networkManager, IStreamHelper streamHelper, IFileSystem fileSystem,
+ public HttpListener(
+ ILogger logger,
+ ICryptoProvider cryptoProvider,
+ ISocketFactory socketFactory,
+ IStreamHelper streamHelper,
+ IFileSystem fileSystem,
IEnvironmentInfo environmentInfo)
{
_logger = logger;
CryptoProvider = cryptoProvider;
SocketFactory = socketFactory;
- NetworkManager = networkManager;
StreamHelper = streamHelper;
FileSystem = fileSystem;
EnvironmentInfo = environmentInfo;
+
prefixes = new HttpListenerPrefixCollection(logger, this);
registry = new Dictionary<HttpListenerContext, HttpListenerContext>();
connections = new Dictionary<HttpConnection, HttpConnection>();
auth_schemes = AuthenticationSchemes.Anonymous;
}
- public HttpListener(ILogger logger, X509Certificate certificate, ICryptoProvider cryptoProvider,
- ISocketFactory socketFactory, INetworkManager networkManager, IStreamHelper streamHelper,
- IFileSystem fileSystem, IEnvironmentInfo environmentInfo)
- : this(logger, cryptoProvider, socketFactory, networkManager, streamHelper, fileSystem, environmentInfo)
+ public HttpListener(
+ ILogger logger,
+ X509Certificate certificate,
+ ICryptoProvider cryptoProvider,
+ ISocketFactory socketFactory,
+ IStreamHelper streamHelper,
+ IFileSystem fileSystem,
+ IEnvironmentInfo environmentInfo)
+ : this(logger, cryptoProvider, socketFactory, streamHelper, fileSystem, environmentInfo)
{
_certificate = certificate;
}
diff --git a/SocketHttpListener/Net/HttpListenerPrefixCollection.cs b/SocketHttpListener/Net/HttpListenerPrefixCollection.cs
index 97dc6797c..400a1adb6 100644
--- a/SocketHttpListener/Net/HttpListenerPrefixCollection.cs
+++ b/SocketHttpListener/Net/HttpListenerPrefixCollection.cs
@@ -7,18 +7,18 @@ namespace SocketHttpListener.Net
{
public class HttpListenerPrefixCollection : ICollection<string>, IEnumerable<string>, IEnumerable
{
- List<string> prefixes = new List<string>();
- HttpListener listener;
+ private List<string> _prefixes = new List<string>();
+ private HttpListener _listener;
private ILogger _logger;
internal HttpListenerPrefixCollection(ILogger logger, HttpListener listener)
{
_logger = logger;
- this.listener = listener;
+ _listener = listener;
}
- public int Count => prefixes.Count;
+ public int Count => _prefixes.Count;
public bool IsReadOnly => false;
@@ -26,61 +26,90 @@ namespace SocketHttpListener.Net
public void Add(string uriPrefix)
{
- listener.CheckDisposed();
+ _listener.CheckDisposed();
//ListenerPrefix.CheckUri(uriPrefix);
- if (prefixes.Contains(uriPrefix))
+ if (_prefixes.Contains(uriPrefix))
+ {
return;
+ }
- prefixes.Add(uriPrefix);
- if (listener.IsListening)
- HttpEndPointManager.AddPrefix(_logger, uriPrefix, listener);
+ _prefixes.Add(uriPrefix);
+ if (_listener.IsListening)
+ {
+ HttpEndPointManager.AddPrefix(_logger, uriPrefix, _listener);
+ }
+ }
+
+ public void AddRange(IEnumerable<string> uriPrefixes)
+ {
+ _listener.CheckDisposed();
+
+ foreach (var uriPrefix in uriPrefixes)
+ {
+ if (_prefixes.Contains(uriPrefix))
+ {
+ continue;
+ }
+
+ _prefixes.Add(uriPrefix);
+ if (_listener.IsListening)
+ {
+ HttpEndPointManager.AddPrefix(_logger, uriPrefix, _listener);
+ }
+ }
}
public void Clear()
{
- listener.CheckDisposed();
- prefixes.Clear();
- if (listener.IsListening)
- HttpEndPointManager.RemoveListener(_logger, listener);
+ _listener.CheckDisposed();
+ _prefixes.Clear();
+ if (_listener.IsListening)
+ {
+ HttpEndPointManager.RemoveListener(_logger, _listener);
+ }
}
public bool Contains(string uriPrefix)
{
- listener.CheckDisposed();
- return prefixes.Contains(uriPrefix);
+ _listener.CheckDisposed();
+ return _prefixes.Contains(uriPrefix);
}
public void CopyTo(string[] array, int offset)
{
- listener.CheckDisposed();
- prefixes.CopyTo(array, offset);
+ _listener.CheckDisposed();
+ _prefixes.CopyTo(array, offset);
}
public void CopyTo(Array array, int offset)
{
- listener.CheckDisposed();
- ((ICollection)prefixes).CopyTo(array, offset);
+ _listener.CheckDisposed();
+ ((ICollection)_prefixes).CopyTo(array, offset);
}
public IEnumerator<string> GetEnumerator()
{
- return prefixes.GetEnumerator();
+ return _prefixes.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
- return prefixes.GetEnumerator();
+ return _prefixes.GetEnumerator();
}
public bool Remove(string uriPrefix)
{
- listener.CheckDisposed();
+ _listener.CheckDisposed();
if (uriPrefix == null)
+ {
throw new ArgumentNullException(nameof(uriPrefix));
+ }
- bool result = prefixes.Remove(uriPrefix);
- if (result && listener.IsListening)
- HttpEndPointManager.RemovePrefix(_logger, uriPrefix, listener);
+ bool result = _prefixes.Remove(uriPrefix);
+ if (result && _listener.IsListening)
+ {
+ HttpEndPointManager.RemovePrefix(_logger, uriPrefix, _listener);
+ }
return result;
}
diff --git a/SocketHttpListener/WebSocket.cs b/SocketHttpListener/WebSocket.cs
index 128bc8b97..0dcb6a64b 100644
--- a/SocketHttpListener/WebSocket.cs
+++ b/SocketHttpListener/WebSocket.cs
@@ -30,9 +30,9 @@ namespace SocketHttpListener
private CookieCollection _cookies;
private AutoResetEvent _exitReceiving;
private object _forConn;
- private object _forEvent;
+ private readonly SemaphoreSlim _forEvent = new SemaphoreSlim(1, 1);
private object _forMessageEventQueue;
- private object _forSend;
+ private readonly SemaphoreSlim _forSend = new SemaphoreSlim(1, 1);
private const string _guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
private Queue<MessageEventArgs> _messageEventQueue;
private string _protocol;
@@ -109,12 +109,15 @@ namespace SocketHttpListener
#region Private Methods
- private void close(CloseStatusCode code, string reason, bool wait)
+ private async Task CloseAsync(CloseStatusCode code, string reason, bool wait)
{
- close(new PayloadData(((ushort)code).Append(reason)), !code.IsReserved(), wait);
+ await CloseAsync(new PayloadData(
+ await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)),
+ !code.IsReserved(),
+ wait).ConfigureAwait(false);
}
- private void close(PayloadData payload, bool send, bool wait)
+ private async Task CloseAsync(PayloadData payload, bool send, bool wait)
{
lock (_forConn)
{
@@ -126,11 +129,12 @@ namespace SocketHttpListener
_readyState = WebSocketState.CloseSent;
}
- var e = new CloseEventArgs(payload);
- e.WasClean =
- closeHandshake(
+ var e = new CloseEventArgs(payload)
+ {
+ WasClean = await CloseHandshakeAsync(
send ? WebSocketFrame.CreateCloseFrame(Mask.Unmask, payload).ToByteArray() : null,
- wait ? 1000 : 0);
+ wait ? 1000 : 0).ConfigureAwait(false)
+ };
_readyState = WebSocketState.Closed;
try
@@ -143,9 +147,9 @@ namespace SocketHttpListener
}
}
- private bool closeHandshake(byte[] frameAsBytes, int millisecondsTimeout)
+ private async Task<bool> CloseHandshakeAsync(byte[] frameAsBytes, int millisecondsTimeout)
{
- var sent = frameAsBytes != null && writeBytes(frameAsBytes);
+ var sent = frameAsBytes != null && await WriteBytesAsync(frameAsBytes).ConfigureAwait(false);
var received =
millisecondsTimeout == 0 ||
(sent && _exitReceiving != null && _exitReceiving.WaitOne(millisecondsTimeout));
@@ -189,11 +193,11 @@ namespace SocketHttpListener
_context = null;
}
- private bool concatenateFragmentsInto(Stream dest)
+ private async Task<bool> ConcatenateFragmentsIntoAsync(Stream dest)
{
while (true)
{
- var frame = WebSocketFrame.Read(_stream, true);
+ var frame = await WebSocketFrame.ReadAsync(_stream, true).ConfigureAwait(false);
if (frame.IsFinal)
{
/* FINAL */
@@ -221,7 +225,7 @@ namespace SocketHttpListener
// CLOSE
if (frame.IsClose)
- return processCloseFrame(frame);
+ return await ProcessCloseFrameAsync(frame).ConfigureAwait(false);
}
else
{
@@ -236,10 +240,10 @@ namespace SocketHttpListener
}
// ?
- return processUnsupportedFrame(
+ return await ProcessUnsupportedFrameAsync(
frame,
CloseStatusCode.IncorrectData,
- "An incorrect data has been received while receiving fragmented data.");
+ "An incorrect data has been received while receiving fragmented data.").ConfigureAwait(false);
}
return true;
@@ -299,44 +303,42 @@ namespace SocketHttpListener
_compression = CompressionMethod.None;
_cookies = new CookieCollection();
_forConn = new object();
- _forEvent = new object();
- _forSend = new object();
_messageEventQueue = new Queue<MessageEventArgs>();
_forMessageEventQueue = ((ICollection)_messageEventQueue).SyncRoot;
_readyState = WebSocketState.Connecting;
}
- private void open()
+ private async Task OpenAsync()
{
try
{
startReceiving();
- lock (_forEvent)
- {
- try
- {
- if (OnOpen != null)
- {
- OnOpen(this, EventArgs.Empty);
- }
- }
- catch (Exception ex)
- {
- processException(ex, "An exception has occurred while OnOpen.");
- }
- }
}
catch (Exception ex)
{
- processException(ex, "An exception has occurred while opening.");
+ await ProcessExceptionAsync(ex, "An exception has occurred while opening.").ConfigureAwait(false);
+ }
+
+ await _forEvent.WaitAsync().ConfigureAwait(false);
+ try
+ {
+ OnOpen?.Invoke(this, EventArgs.Empty);
+ }
+ catch (Exception ex)
+ {
+ await ProcessExceptionAsync(ex, "An exception has occurred while OnOpen.").ConfigureAwait(false);
+ }
+ finally
+ {
+ _forEvent.Release();
}
}
- private bool processCloseFrame(WebSocketFrame frame)
+ private async Task<bool> ProcessCloseFrameAsync(WebSocketFrame frame)
{
var payload = frame.PayloadData;
- close(payload, !payload.ContainsReservedCloseStatusCode, false);
+ await CloseAsync(payload, !payload.ContainsReservedCloseStatusCode, false).ConfigureAwait(false);
return false;
}
@@ -352,7 +354,7 @@ namespace SocketHttpListener
return true;
}
- private void processException(Exception exception, string message)
+ private async Task ProcessExceptionAsync(Exception exception, string message)
{
var code = CloseStatusCode.Abnormal;
var reason = message;
@@ -365,25 +367,31 @@ namespace SocketHttpListener
error(message ?? code.GetMessage(), exception);
if (_readyState == WebSocketState.Connecting)
- Close(HttpStatusCode.BadRequest);
+ {
+ await CloseAsync(HttpStatusCode.BadRequest).ConfigureAwait(false);
+ }
else
- close(code, reason ?? code.GetMessage(), false);
+ {
+ await CloseAsync(code, reason ?? code.GetMessage(), false).ConfigureAwait(false);
+ }
}
- private bool processFragmentedFrame(WebSocketFrame frame)
+ private Task<bool> ProcessFragmentedFrameAsync(WebSocketFrame frame)
{
return frame.IsContinuation // Not first fragment
- ? true
- : processFragments(frame);
+ ? Task.FromResult(true)
+ : ProcessFragmentsAsync(frame);
}
- private bool processFragments(WebSocketFrame first)
+ private async Task<bool> ProcessFragmentsAsync(WebSocketFrame first)
{
using (var buff = new MemoryStream())
{
buff.WriteBytes(first.PayloadData.ApplicationData);
- if (!concatenateFragmentsInto(buff))
+ if (!await ConcatenateFragmentsIntoAsync(buff).ConfigureAwait(false))
+ {
return false;
+ }
byte[] data;
if (_compression != CompressionMethod.None)
@@ -412,36 +420,38 @@ namespace SocketHttpListener
return true;
}
- private bool processUnsupportedFrame(WebSocketFrame frame, CloseStatusCode code, string reason)
+ private async Task<bool> ProcessUnsupportedFrameAsync(WebSocketFrame frame, CloseStatusCode code, string reason)
{
- processException(new WebSocketException(code, reason), null);
+ await ProcessExceptionAsync(new WebSocketException(code, reason), null).ConfigureAwait(false);
return false;
}
- private bool processWebSocketFrame(WebSocketFrame frame)
+ private Task<bool> ProcessWebSocketFrameAsync(WebSocketFrame frame)
{
+ // TODO: @bond change to if/else chain
return frame.IsCompressed && _compression == CompressionMethod.None
- ? processUnsupportedFrame(
+ ? ProcessUnsupportedFrameAsync(
frame,
CloseStatusCode.IncorrectData,
"A compressed data has been received without available decompression method.")
: frame.IsFragmented
- ? processFragmentedFrame(frame)
+ ? ProcessFragmentedFrameAsync(frame)
: frame.IsData
- ? processDataFrame(frame)
+ ? Task.FromResult(processDataFrame(frame))
: frame.IsPing
- ? processPingFrame(frame)
+ ? Task.FromResult(processPingFrame(frame))
: frame.IsPong
- ? processPongFrame(frame)
+ ? Task.FromResult(processPongFrame(frame))
: frame.IsClose
- ? processCloseFrame(frame)
- : processUnsupportedFrame(frame, CloseStatusCode.PolicyViolation, null);
+ ? ProcessCloseFrameAsync(frame)
+ : ProcessUnsupportedFrameAsync(frame, CloseStatusCode.PolicyViolation, null);
}
- private bool send(Opcode opcode, Stream stream)
+ private async Task<bool> SendAsync(Opcode opcode, Stream stream)
{
- lock (_forSend)
+ await _forSend.WaitAsync().ConfigureAwait(false);
+ try
{
var src = stream;
var compressed = false;
@@ -454,7 +464,7 @@ namespace SocketHttpListener
compressed = true;
}
- sent = send(opcode, Mask.Unmask, stream, compressed);
+ sent = await SendAsync(opcode, Mask.Unmask, stream, compressed).ConfigureAwait(false);
if (!sent)
error("Sending a data has been interrupted.");
}
@@ -472,16 +482,20 @@ namespace SocketHttpListener
return sent;
}
+ finally
+ {
+ _forSend.Release();
+ }
}
- private bool send(Opcode opcode, Mask mask, Stream stream, bool compressed)
+ private async Task<bool> SendAsync(Opcode opcode, Mask mask, Stream stream, bool compressed)
{
var len = stream.Length;
/* Not fragmented */
if (len == 0)
- return send(Fin.Final, opcode, mask, new byte[0], compressed);
+ return await SendAsync(Fin.Final, opcode, mask, new byte[0], compressed).ConfigureAwait(false);
var quo = len / FragmentLength;
var rem = (int)(len % FragmentLength);
@@ -490,26 +504,26 @@ namespace SocketHttpListener
if (quo == 0)
{
buff = new byte[rem];
- return stream.Read(buff, 0, rem) == rem &&
- send(Fin.Final, opcode, mask, buff, compressed);
+ return await stream.ReadAsync(buff, 0, rem).ConfigureAwait(false) == rem &&
+ await SendAsync(Fin.Final, opcode, mask, buff, compressed).ConfigureAwait(false);
}
buff = new byte[FragmentLength];
if (quo == 1 && rem == 0)
- return stream.Read(buff, 0, FragmentLength) == FragmentLength &&
- send(Fin.Final, opcode, mask, buff, compressed);
+ return await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) == FragmentLength &&
+ await SendAsync(Fin.Final, opcode, mask, buff, compressed).ConfigureAwait(false);
/* Send fragmented */
// Begin
- if (stream.Read(buff, 0, FragmentLength) != FragmentLength ||
- !send(Fin.More, opcode, mask, buff, compressed))
+ if (await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) != FragmentLength ||
+ !await SendAsync(Fin.More, opcode, mask, buff, compressed).ConfigureAwait(false))
return false;
var n = rem == 0 ? quo - 2 : quo - 1;
for (long i = 0; i < n; i++)
- if (stream.Read(buff, 0, FragmentLength) != FragmentLength ||
- !send(Fin.More, Opcode.Cont, mask, buff, compressed))
+ if (await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) != FragmentLength ||
+ !await SendAsync(Fin.More, Opcode.Cont, mask, buff, compressed).ConfigureAwait(false))
return false;
// End
@@ -518,98 +532,88 @@ namespace SocketHttpListener
else
buff = new byte[rem];
- return stream.Read(buff, 0, rem) == rem &&
- send(Fin.Final, Opcode.Cont, mask, buff, compressed);
+ return await stream.ReadAsync(buff, 0, rem).ConfigureAwait(false) == rem &&
+ await SendAsync(Fin.Final, Opcode.Cont, mask, buff, compressed).ConfigureAwait(false);
}
- private bool send(Fin fin, Opcode opcode, Mask mask, byte[] data, bool compressed)
+ private Task<bool> SendAsync(Fin fin, Opcode opcode, Mask mask, byte[] data, bool compressed)
{
lock (_forConn)
{
if (_readyState != WebSocketState.Open)
{
- return false;
+ return Task.FromResult(false);
}
- return writeBytes(
+ return WriteBytesAsync(
WebSocketFrame.CreateWebSocketFrame(fin, opcode, mask, data, compressed).ToByteArray());
}
}
- private Task sendAsync(Opcode opcode, Stream stream)
- {
- var completionSource = new TaskCompletionSource<bool>();
- Task.Run(() =>
- {
- try
- {
- send(opcode, stream);
- completionSource.TrySetResult(true);
- }
- catch (Exception ex)
- {
- completionSource.TrySetException(ex);
- }
- });
- return completionSource.Task;
- }
-
// As server
- private bool sendHttpResponse(HttpResponse response)
- {
- return writeBytes(response.ToByteArray());
- }
+ private Task<bool> SendHttpResponseAsync(HttpResponse response)
+ => WriteBytesAsync(response.ToByteArray());
private void startReceiving()
{
if (_messageEventQueue.Count > 0)
+ {
_messageEventQueue.Clear();
+ }
_exitReceiving = new AutoResetEvent(false);
_receivePong = new AutoResetEvent(false);
Action receive = null;
- receive = () => WebSocketFrame.ReadAsync(
- _stream,
- true,
- frame =>
- {
- if (processWebSocketFrame(frame) && _readyState != WebSocketState.Closed)
- {
- receive();
-
- if (!frame.IsData)
- return;
-
- lock (_forEvent)
- {
- try
- {
- var e = dequeueFromMessageEventQueue();
- if (e != null && _readyState == WebSocketState.Open)
- OnMessage.Emit(this, e);
- }
- catch (Exception ex)
- {
- processException(ex, "An exception has occurred while OnMessage.");
- }
- }
- }
- else if (_exitReceiving != null)
- {
- _exitReceiving.Set();
- }
- },
- ex => processException(ex, "An exception has occurred while receiving a message."));
+ receive = async () => await WebSocketFrame.ReadAsync(
+ _stream,
+ true,
+ async frame =>
+ {
+ if (await ProcessWebSocketFrameAsync(frame).ConfigureAwait(false) && _readyState != WebSocketState.Closed)
+ {
+ receive();
+
+ if (!frame.IsData)
+ {
+ return;
+ }
+
+ await _forEvent.WaitAsync().ConfigureAwait(false);
+
+ try
+ {
+ var e = dequeueFromMessageEventQueue();
+ if (e != null && _readyState == WebSocketState.Open)
+ {
+ OnMessage.Emit(this, e);
+ }
+ }
+ catch (Exception ex)
+ {
+ await ProcessExceptionAsync(ex, "An exception has occurred while OnMessage.").ConfigureAwait(false);
+ }
+ finally
+ {
+ _forEvent.Release();
+ }
+
+ }
+ else if (_exitReceiving != null)
+ {
+ _exitReceiving.Set();
+ }
+ },
+ async ex => await ProcessExceptionAsync(ex, "An exception has occurred while receiving a message.")).ConfigureAwait(false);
receive();
}
- private bool writeBytes(byte[] data)
+ private async Task<bool> WriteBytesAsync(byte[] data)
{
try
{
- _stream.Write(data, 0, data.Length);
+ await _stream.WriteAsync(data, 0, data.Length).ConfigureAwait(false);
return true;
}
catch (Exception)
@@ -623,10 +627,10 @@ namespace SocketHttpListener
#region Internal Methods
// As server
- internal void Close(HttpResponse response)
+ internal async Task CloseAsync(HttpResponse response)
{
_readyState = WebSocketState.CloseSent;
- sendHttpResponse(response);
+ await SendHttpResponseAsync(response).ConfigureAwait(false);
closeServerResources();
@@ -634,22 +638,20 @@ namespace SocketHttpListener
}
// As server
- internal void Close(HttpStatusCode code)
- {
- Close(createHandshakeCloseResponse(code));
- }
+ internal Task CloseAsync(HttpStatusCode code)
+ => CloseAsync(createHandshakeCloseResponse(code));
// As server
- public void ConnectAsServer()
+ public async Task ConnectAsServer()
{
try
{
_readyState = WebSocketState.Open;
- open();
+ await OpenAsync().ConfigureAwait(false);
}
catch (Exception ex)
{
- processException(ex, "An exception has occurred while connecting.");
+ await ProcessExceptionAsync(ex, "An exception has occurred while connecting.").ConfigureAwait(false);
}
}
@@ -660,18 +662,18 @@ namespace SocketHttpListener
/// <summary>
/// Closes the WebSocket connection, and releases all associated resources.
/// </summary>
- public void Close()
+ public Task CloseAsync()
{
var msg = _readyState.CheckIfClosable();
if (msg != null)
{
error(msg);
- return;
+ return Task.CompletedTask;
}
var send = _readyState == WebSocketState.Open;
- close(new PayloadData(), send, send);
+ return CloseAsync(new PayloadData(), send, send);
}
/// <summary>
@@ -689,11 +691,11 @@ namespace SocketHttpListener
/// <param name="reason">
/// A <see cref="string"/> that represents the reason for the close.
/// </param>
- public void Close(CloseStatusCode code, string reason)
+ public async Task CloseAsync(CloseStatusCode code, string reason)
{
byte[] data = null;
var msg = _readyState.CheckIfClosable() ??
- (data = ((ushort)code).Append(reason)).CheckIfValidControlData("reason");
+ (data = await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)).CheckIfValidControlData("reason");
if (msg != null)
{
@@ -703,7 +705,7 @@ namespace SocketHttpListener
}
var send = _readyState == WebSocketState.Open && !code.IsReserved();
- close(new PayloadData(data), send, send);
+ await CloseAsync(new PayloadData(data), send, send).ConfigureAwait(false);
}
/// <summary>
@@ -728,7 +730,7 @@ namespace SocketHttpListener
throw new Exception(msg);
}
- return sendAsync(Opcode.Binary, new MemoryStream(data));
+ return SendAsync(Opcode.Binary, new MemoryStream(data));
}
/// <summary>
@@ -753,7 +755,7 @@ namespace SocketHttpListener
throw new Exception(msg);
}
- return sendAsync(Opcode.Text, new MemoryStream(Encoding.UTF8.GetBytes(data)));
+ return SendAsync(Opcode.Text, new MemoryStream(Encoding.UTF8.GetBytes(data)));
}
#endregion
@@ -768,7 +770,7 @@ namespace SocketHttpListener
/// </remarks>
void IDisposable.Dispose()
{
- Close(CloseStatusCode.Away, null);
+ CloseAsync(CloseStatusCode.Away, null).GetAwaiter().GetResult();
}
#endregion
diff --git a/SocketHttpListener/WebSocketFrame.cs b/SocketHttpListener/WebSocketFrame.cs
index 74ed23c45..8ec64026b 100644
--- a/SocketHttpListener/WebSocketFrame.cs
+++ b/SocketHttpListener/WebSocketFrame.cs
@@ -2,6 +2,7 @@ using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
+using System.Threading.Tasks;
namespace SocketHttpListener
{
@@ -177,7 +178,7 @@ namespace SocketHttpListener
return opcode == Opcode.Text || opcode == Opcode.Binary;
}
- private static WebSocketFrame read(byte[] header, Stream stream, bool unmask)
+ private static async Task<WebSocketFrame> ReadAsync(byte[] header, Stream stream, bool unmask)
{
/* Header */
@@ -229,7 +230,7 @@ namespace SocketHttpListener
? 2
: 8;
- var extPayloadLen = size > 0 ? stream.ReadBytes(size) : new byte[0];
+ var extPayloadLen = size > 0 ? await stream.ReadBytesAsync(size).ConfigureAwait(false) : Array.Empty<byte>();
if (size > 0 && extPayloadLen.Length != size)
throw new WebSocketException(
"The 'Extended Payload Length' of a frame cannot be read from the data source.");
@@ -239,7 +240,7 @@ namespace SocketHttpListener
/* Masking Key */
var masked = mask == Mask.Mask;
- var maskingKey = masked ? stream.ReadBytes(4) : new byte[0];
+ var maskingKey = masked ? await stream.ReadBytesAsync(4).ConfigureAwait(false) : Array.Empty<byte>();
if (masked && maskingKey.Length != 4)
throw new WebSocketException(
"The 'Masking Key' of a frame cannot be read from the data source.");
@@ -264,8 +265,8 @@ namespace SocketHttpListener
"The length of 'Payload Data' of a frame is greater than the allowable length.");
data = payloadLen > 126
- ? stream.ReadBytes((long)len, 1024)
- : stream.ReadBytes((int)len);
+ ? await stream.ReadBytesAsync((long)len, 1024).ConfigureAwait(false)
+ : await stream.ReadBytesAsync((int)len).ConfigureAwait(false);
//if (data.LongLength != (long)len)
// throw new WebSocketException(
@@ -273,7 +274,7 @@ namespace SocketHttpListener
}
else
{
- data = new byte[0];
+ data = Array.Empty<byte>();
}
var payload = new PayloadData(data, masked);
@@ -281,7 +282,7 @@ namespace SocketHttpListener
{
payload.Mask(maskingKey);
frame._mask = Mask.Unmask;
- frame._maskingKey = new byte[0];
+ frame._maskingKey = Array.Empty<byte>();
}
frame._payloadData = payload;
@@ -302,10 +303,10 @@ namespace SocketHttpListener
return new WebSocketFrame(Opcode.Close, mask, payload);
}
- internal static WebSocketFrame CreateCloseFrame(Mask mask, CloseStatusCode code, string reason)
+ internal static async Task<WebSocketFrame> CreateCloseFrameAsync(Mask mask, CloseStatusCode code, string reason)
{
return new WebSocketFrame(
- Opcode.Close, mask, new PayloadData(((ushort)code).Append(reason)));
+ Opcode.Close, mask, new PayloadData(await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)));
}
internal static WebSocketFrame CreatePingFrame(Mask mask)
@@ -329,41 +330,39 @@ namespace SocketHttpListener
return new WebSocketFrame(fin, opcode, mask, new PayloadData(data), compressed);
}
- internal static WebSocketFrame Read(Stream stream)
- {
- return Read(stream, true);
- }
+ internal static Task<WebSocketFrame> ReadAsync(Stream stream)
+ => ReadAsync(stream, true);
- internal static WebSocketFrame Read(Stream stream, bool unmask)
+ internal static async Task<WebSocketFrame> ReadAsync(Stream stream, bool unmask)
{
- var header = stream.ReadBytes(2);
+ var header = await stream.ReadBytesAsync(2).ConfigureAwait(false);
if (header.Length != 2)
+ {
throw new WebSocketException(
"The header part of a frame cannot be read from the data source.");
+ }
- return read(header, stream, unmask);
+ return await ReadAsync(header, stream, unmask).ConfigureAwait(false);
}
- internal static async void ReadAsync(
+ internal static async Task ReadAsync(
Stream stream, bool unmask, Action<WebSocketFrame> completed, Action<Exception> error)
{
try
{
var header = await stream.ReadBytesAsync(2).ConfigureAwait(false);
if (header.Length != 2)
+ {
throw new WebSocketException(
"The header part of a frame cannot be read from the data source.");
+ }
- var frame = read(header, stream, unmask);
- if (completed != null)
- completed(frame);
+ var frame = await ReadAsync(header, stream, unmask).ConfigureAwait(false);
+ completed?.Invoke(frame);
}
catch (Exception ex)
{
- if (error != null)
- {
- error(ex);
- }
+ error.Invoke(ex);
}
}