diff options
| author | stefan <stefan@hegedues.at> | 2018-09-12 19:26:21 +0200 |
|---|---|---|
| committer | stefan <stefan@hegedues.at> | 2018-09-12 19:26:21 +0200 |
| commit | 48facb797ed912e4ea6b04b17d1ff190ac2daac4 (patch) | |
| tree | 8dae77a31670a888d733484cb17dd4077d5444e8 /SocketHttpListener | |
| parent | c32d8656382a0eacb301692e0084377fc433ae9b (diff) | |
Update to 3.5.2 and .net core 2.1
Diffstat (limited to 'SocketHttpListener')
32 files changed, 2810 insertions, 2591 deletions
diff --git a/SocketHttpListener/Ext.cs b/SocketHttpListener/Ext.cs index 87f0887ed..125775180 100644 --- a/SocketHttpListener/Ext.cs +++ b/SocketHttpListener/Ext.cs @@ -7,9 +7,8 @@ using System.Net; using System.Text; using System.Threading.Tasks; using MediaBrowser.Model.Services; -using SocketHttpListener.Net; -using HttpListenerResponse = SocketHttpListener.Net.HttpListenerResponse; using HttpStatusCode = SocketHttpListener.Net.HttpStatusCode; +using WebSocketState = System.Net.WebSockets.WebSocketState; namespace SocketHttpListener { @@ -129,7 +128,7 @@ namespace SocketHttpListener internal static string CheckIfClosable(this WebSocketState state) { - return state == WebSocketState.Closing + return state == WebSocketState.CloseSent ? "While closing the WebSocket connection." : state == WebSocketState.Closed ? "The WebSocket connection has already been closed." @@ -140,7 +139,7 @@ namespace SocketHttpListener { return state == WebSocketState.Connecting ? "A WebSocket connection isn't established." - : state == WebSocketState.Closing + : state == WebSocketState.CloseSent ? "While closing the WebSocket connection." : state == WebSocketState.Closed ? "The WebSocket connection has already been closed." @@ -154,20 +153,6 @@ namespace SocketHttpListener : null; } - internal static string CheckIfValidSendData(this byte[] data) - { - return data == null - ? "'data' must not be null." - : null; - } - - internal static string CheckIfValidSendData(this string data) - { - return data == null - ? "'data' must not be null." - : null; - } - internal static Stream Compress(this Stream stream, CompressionMethod method) { return method == CompressionMethod.Deflate @@ -632,24 +617,6 @@ namespace SocketHttpListener } /// <summary> - /// Emits the specified <see cref="EventHandler"/> delegate if it isn't <see langword="null"/>. - /// </summary> - /// <param name="eventHandler"> - /// A <see cref="EventHandler"/> to emit. - /// </param> - /// <param name="sender"> - /// An <see cref="object"/> from which emits this <paramref name="eventHandler"/>. - /// </param> - /// <param name="e"> - /// A <see cref="EventArgs"/> that contains no event data. - /// </param> - public static void Emit(this EventHandler eventHandler, object sender, EventArgs e) - { - if (eventHandler != null) - eventHandler(sender, e); - } - - /// <summary> /// Emits the specified <c>EventHandler<TEventArgs></c> delegate /// if it isn't <see langword="null"/>. /// </summary> @@ -674,27 +641,6 @@ namespace SocketHttpListener } /// <summary> - /// Gets the collection of the HTTP cookies from the specified HTTP <paramref name="headers"/>. - /// </summary> - /// <returns> - /// A <see cref="CookieCollection"/> that receives a collection of the HTTP cookies. - /// </returns> - /// <param name="headers"> - /// A <see cref="QueryParamCollection"/> that contains a collection of the HTTP headers. - /// </param> - /// <param name="response"> - /// <c>true</c> if <paramref name="headers"/> is a collection of the response headers; - /// otherwise, <c>false</c>. - /// </param> - public static CookieCollection GetCookies(this QueryParamCollection headers, bool response) - { - var name = response ? "Set-Cookie" : "Cookie"; - return headers == null || !headers.Contains(name) - ? new CookieCollection() - : CookieHelper.Parse(headers[name], response); - } - - /// <summary> /// Gets the description of the specified HTTP status <paramref name="code"/>. /// </summary> /// <returns> @@ -709,52 +655,6 @@ namespace SocketHttpListener } /// <summary> - /// Gets the name from the specified <see cref="string"/> that contains a pair of name and - /// value separated by a separator string. - /// </summary> - /// <returns> - /// A <see cref="string"/> that represents the name if any; otherwise, <c>null</c>. - /// </returns> - /// <param name="nameAndValue"> - /// A <see cref="string"/> that contains a pair of name and value separated by a separator - /// string. - /// </param> - /// <param name="separator"> - /// A <see cref="string"/> that represents a separator string. - /// </param> - public static string GetName(this string nameAndValue, string separator) - { - return (nameAndValue != null && nameAndValue.Length > 0) && - (separator != null && separator.Length > 0) - ? nameAndValue.GetNameInternal(separator) - : null; - } - - /// <summary> - /// Gets the name and value from the specified <see cref="string"/> that contains a pair of - /// name and value separated by a separator string. - /// </summary> - /// <returns> - /// A <c>KeyValuePair<string, string></c> that represents the name and value if any. - /// </returns> - /// <param name="nameAndValue"> - /// A <see cref="string"/> that contains a pair of name and value separated by a separator - /// string. - /// </param> - /// <param name="separator"> - /// A <see cref="string"/> that represents a separator string. - /// </param> - public static KeyValuePair<string, string> GetNameAndValue( - this string nameAndValue, string separator) - { - var name = nameAndValue.GetName(separator); - var value = nameAndValue.GetValue(separator); - return name != null - ? new KeyValuePair<string, string>(name, value) - : new KeyValuePair<string, string>(null, null); - } - - /// <summary> /// Gets the description of the specified HTTP status <paramref name="code"/>. /// </summary> /// <returns> @@ -819,28 +719,6 @@ namespace SocketHttpListener } /// <summary> - /// Gets the value from the specified <see cref="string"/> that contains a pair of name and - /// value separated by a separator string. - /// </summary> - /// <returns> - /// A <see cref="string"/> that represents the value if any; otherwise, <c>null</c>. - /// </returns> - /// <param name="nameAndValue"> - /// A <see cref="string"/> that contains a pair of name and value separated by a separator - /// string. - /// </param> - /// <param name="separator"> - /// A <see cref="string"/> that represents a separator string. - /// </param> - public static string GetValue(this string nameAndValue, string separator) - { - return (nameAndValue != null && nameAndValue.Length > 0) && - (separator != null && separator.Length > 0) - ? nameAndValue.GetValueInternal(separator) - : null; - } - - /// <summary> /// Determines whether the specified <see cref="ByteOrder"/> is host /// (this computer architecture) byte order. /// </summary> diff --git a/SocketHttpListener/HttpResponse.cs b/SocketHttpListener/HttpResponse.cs index 5aca28c7c..154a3d8e9 100644 --- a/SocketHttpListener/HttpResponse.cs +++ b/SocketHttpListener/HttpResponse.cs @@ -7,6 +7,7 @@ using HttpStatusCode = SocketHttpListener.Net.HttpStatusCode; using HttpVersion = SocketHttpListener.Net.HttpVersion; using System.Linq; using MediaBrowser.Model.Services; +using SocketHttpListener.Net; namespace SocketHttpListener { @@ -51,10 +52,18 @@ namespace SocketHttpListener { get { - return Headers.GetCookies(true); + return GetCookies(Headers, true); } } + private CookieCollection GetCookies(QueryParamCollection headers, bool response) + { + var name = response ? "Set-Cookie" : "Cookie"; + return headers == null || !headers.Contains(name) + ? new CookieCollection() + : CookieHelper.Parse(headers[name], response); + } + public bool IsProxyAuthenticationRequired { get @@ -111,17 +120,6 @@ namespace SocketHttpListener return res; } - internal static HttpResponse CreateWebSocketResponse() - { - var res = new HttpResponse(HttpStatusCode.SwitchingProtocols); - - var headers = res.Headers; - headers["Upgrade"] = "websocket"; - headers["Connection"] = "Upgrade"; - - return res; - } - #endregion #region Public Methods diff --git a/SocketHttpListener/Net/AuthenticationTypes.cs b/SocketHttpListener/Net/AuthenticationTypes.cs new file mode 100644 index 000000000..df6b9d576 --- /dev/null +++ b/SocketHttpListener/Net/AuthenticationTypes.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace SocketHttpListener.Net +{ + internal class AuthenticationTypes + { + internal const string NTLM = "NTLM"; + internal const string Negotiate = "Negotiate"; + internal const string Basic = "Basic"; + } +} diff --git a/SocketHttpListener/Net/EndPointListener.cs b/SocketHttpListener/Net/EndPointListener.cs deleted file mode 100644 index 48c0ae7cb..000000000 --- a/SocketHttpListener/Net/EndPointListener.cs +++ /dev/null @@ -1,433 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.IO; -using System.Net; -using System.Net.Sockets; -using System.Security.Cryptography.X509Certificates; -using System.Threading; -using MediaBrowser.Model.Cryptography; -using MediaBrowser.Model.IO; -using MediaBrowser.Model.Logging; -using MediaBrowser.Model.Net; -using MediaBrowser.Model.System; -using MediaBrowser.Model.Text; -using SocketHttpListener.Primitives; -using ProtocolType = MediaBrowser.Model.Net.ProtocolType; -using SocketType = MediaBrowser.Model.Net.SocketType; - -namespace SocketHttpListener.Net -{ - sealed class EndPointListener - { - HttpListener listener; - IPEndPoint endpoint; - Socket sock; - Dictionary<ListenerPrefix, HttpListener> prefixes; // Dictionary <ListenerPrefix, HttpListener> - List<ListenerPrefix> unhandled; // List<ListenerPrefix> unhandled; host = '*' - List<ListenerPrefix> all; // List<ListenerPrefix> all; host = '+' - X509Certificate cert; - bool secure; - Dictionary<HttpConnection, HttpConnection> unregistered; - private readonly ILogger _logger; - private bool _closed; - private bool _enableDualMode; - private readonly ICryptoProvider _cryptoProvider; - private readonly ISocketFactory _socketFactory; - private readonly ITextEncoding _textEncoding; - private readonly IMemoryStreamFactory _memoryStreamFactory; - private readonly IFileSystem _fileSystem; - private readonly IEnvironmentInfo _environment; - - public EndPointListener(HttpListener listener, IPAddress addr, int port, bool secure, X509Certificate cert, ILogger logger, ICryptoProvider cryptoProvider, ISocketFactory socketFactory, IMemoryStreamFactory memoryStreamFactory, ITextEncoding textEncoding, IFileSystem fileSystem, IEnvironmentInfo environment) - { - this.listener = listener; - _logger = logger; - _cryptoProvider = cryptoProvider; - _socketFactory = socketFactory; - _memoryStreamFactory = memoryStreamFactory; - _textEncoding = textEncoding; - _fileSystem = fileSystem; - _environment = environment; - - this.secure = secure; - this.cert = cert; - - _enableDualMode = addr.Equals(IPAddress.IPv6Any); - endpoint = new IPEndPoint(addr, port); - - prefixes = new Dictionary<ListenerPrefix, HttpListener>(); - unregistered = new Dictionary<HttpConnection, HttpConnection>(); - - CreateSocket(); - } - - internal HttpListener Listener - { - get - { - return listener; - } - } - - private void CreateSocket() - { - try - { - sock = CreateSocket(endpoint.Address.AddressFamily, _enableDualMode); - } - catch (SocketCreateException ex) - { - if (_enableDualMode && endpoint.Address.Equals(IPAddress.IPv6Any) && - (string.Equals(ex.ErrorCode, "AddressFamilyNotSupported", StringComparison.OrdinalIgnoreCase) || - // mono on bsd is throwing this - string.Equals(ex.ErrorCode, "ProtocolNotSupported", StringComparison.OrdinalIgnoreCase))) - { - endpoint = new IPEndPoint(IPAddress.Any, endpoint.Port); - _enableDualMode = false; - sock = CreateSocket(endpoint.Address.AddressFamily, _enableDualMode); - } - else - { - throw; - } - } - - try - { - sock.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); - } - catch (SocketException) - { - // This is not supported on all operating systems (qnap) - } - - sock.Bind(endpoint); - - // This is the number TcpListener uses. - sock.Listen(2147483647); - - new SocketAcceptor(_logger, sock, ProcessAccept, () => _closed).StartAccept(); - _closed = false; - } - - private Socket CreateSocket(AddressFamily addressFamily, bool dualMode) - { - try - { - var socket = new Socket(addressFamily, System.Net.Sockets.SocketType.Stream, System.Net.Sockets.ProtocolType.Tcp); - - if (dualMode) - { - socket.DualMode = true; - } - - return socket; - } - catch (SocketException ex) - { - throw new SocketCreateException(ex.SocketErrorCode.ToString(), ex); - } - catch (ArgumentException ex) - { - if (dualMode) - { - // Mono for BSD incorrectly throws ArgumentException instead of SocketException - throw new SocketCreateException("AddressFamilyNotSupported", ex); - } - else - { - throw; - } - } - } - - private async void ProcessAccept(Socket accepted) - { - try - { - var listener = this; - - if (listener.secure && listener.cert == null) - { - accepted.Close(); - return; - } - - HttpConnection conn = await HttpConnection.Create(_logger, accepted, listener, listener.secure, listener.cert, _cryptoProvider, _memoryStreamFactory, _textEncoding, _fileSystem, _environment).ConfigureAwait(false); - - //_logger.Debug("Adding unregistered connection to {0}. Id: {1}", accepted.RemoteEndPoint, connectionId); - lock (listener.unregistered) - { - listener.unregistered[conn] = conn; - } - conn.BeginReadRequest(); - } - catch (Exception ex) - { - _logger.ErrorException("Error in ProcessAccept", ex); - } - } - - internal void RemoveConnection(HttpConnection conn) - { - lock (unregistered) - { - unregistered.Remove(conn); - } - } - - public bool BindContext(HttpListenerContext context) - { - HttpListenerRequest req = context.Request; - ListenerPrefix prefix; - HttpListener listener = SearchListener(req.Url, out prefix); - if (listener == null) - return false; - - context.Connection.Prefix = prefix; - return true; - } - - public void UnbindContext(HttpListenerContext context) - { - if (context == null || context.Request == null) - return; - - listener.UnregisterContext(context); - } - - HttpListener SearchListener(Uri uri, out ListenerPrefix prefix) - { - prefix = null; - if (uri == null) - return null; - - string host = uri.Host; - int port = uri.Port; - string path = WebUtility.UrlDecode(uri.AbsolutePath); - string path_slash = path[path.Length - 1] == '/' ? path : path + "/"; - - HttpListener best_match = null; - int best_length = -1; - - if (host != null && host != "") - { - var p_ro = prefixes; - foreach (ListenerPrefix p in p_ro.Keys) - { - string ppath = p.Path; - if (ppath.Length < best_length) - continue; - - if (p.Host != host || p.Port != port) - continue; - - if (path.StartsWith(ppath) || path_slash.StartsWith(ppath)) - { - best_length = ppath.Length; - best_match = (HttpListener)p_ro[p]; - prefix = p; - } - } - if (best_length != -1) - return best_match; - } - - List<ListenerPrefix> list = unhandled; - best_match = MatchFromList(host, path, list, out prefix); - if (path != path_slash && best_match == null) - best_match = MatchFromList(host, path_slash, list, out prefix); - if (best_match != null) - return best_match; - - list = all; - best_match = MatchFromList(host, path, list, out prefix); - if (path != path_slash && best_match == null) - best_match = MatchFromList(host, path_slash, list, out prefix); - if (best_match != null) - return best_match; - - return null; - } - - HttpListener MatchFromList(string host, string path, List<ListenerPrefix> list, out ListenerPrefix prefix) - { - prefix = null; - if (list == null) - return null; - - HttpListener best_match = null; - int best_length = -1; - - foreach (ListenerPrefix p in list) - { - string ppath = p.Path; - if (ppath.Length < best_length) - continue; - - if (path.StartsWith(ppath)) - { - best_length = ppath.Length; - best_match = p.Listener; - prefix = p; - } - } - - return best_match; - } - - void AddSpecial(List<ListenerPrefix> coll, ListenerPrefix prefix) - { - if (coll == null) - return; - - foreach (ListenerPrefix p in coll) - { - if (p.Path == prefix.Path) //TODO: code - throw new HttpListenerException(400, "Prefix already in use."); - } - coll.Add(prefix); - } - - bool RemoveSpecial(List<ListenerPrefix> coll, ListenerPrefix prefix) - { - if (coll == null) - return false; - - int c = coll.Count; - for (int i = 0; i < c; i++) - { - ListenerPrefix p = (ListenerPrefix)coll[i]; - if (p.Path == prefix.Path) - { - coll.RemoveAt(i); - return true; - } - } - return false; - } - - void CheckIfRemove() - { - if (prefixes.Count > 0) - return; - - List<ListenerPrefix> list = unhandled; - if (list != null && list.Count > 0) - return; - - list = all; - if (list != null && list.Count > 0) - return; - - EndPointManager.RemoveEndPoint(this, endpoint); - } - - public void Close() - { - _closed = true; - sock.Close(); - lock (unregistered) - { - // - // Clone the list because RemoveConnection can be called from Close - // - var connections = new List<HttpConnection>(unregistered.Keys); - - foreach (HttpConnection c in connections) - c.Close(true); - unregistered.Clear(); - } - } - - public void AddPrefix(ListenerPrefix prefix, HttpListener listener) - { - List<ListenerPrefix> current; - List<ListenerPrefix> future; - if (prefix.Host == "*") - { - do - { - current = unhandled; - future = (current != null) ? current.ToList() : new List<ListenerPrefix>(); - prefix.Listener = listener; - AddSpecial(future, prefix); - } while (Interlocked.CompareExchange(ref unhandled, future, current) != current); - return; - } - - if (prefix.Host == "+") - { - do - { - current = all; - future = (current != null) ? current.ToList() : new List<ListenerPrefix>(); - prefix.Listener = listener; - AddSpecial(future, prefix); - } while (Interlocked.CompareExchange(ref all, future, current) != current); - return; - } - - Dictionary<ListenerPrefix, HttpListener> prefs; - Dictionary<ListenerPrefix, HttpListener> p2; - do - { - prefs = prefixes; - if (prefs.ContainsKey(prefix)) - { - HttpListener other = (HttpListener)prefs[prefix]; - if (other != listener) // TODO: code. - throw new HttpListenerException(400, "There's another listener for " + prefix); - return; - } - p2 = new Dictionary<ListenerPrefix, HttpListener>(prefs); - p2[prefix] = listener; - } while (Interlocked.CompareExchange(ref prefixes, p2, prefs) != prefs); - } - - public void RemovePrefix(ListenerPrefix prefix, HttpListener listener) - { - List<ListenerPrefix> current; - List<ListenerPrefix> future; - if (prefix.Host == "*") - { - do - { - current = unhandled; - future = (current != null) ? current.ToList() : new List<ListenerPrefix>(); - if (!RemoveSpecial(future, prefix)) - break; // Prefix not found - } while (Interlocked.CompareExchange(ref unhandled, future, current) != current); - CheckIfRemove(); - return; - } - - if (prefix.Host == "+") - { - do - { - current = all; - future = (current != null) ? current.ToList() : new List<ListenerPrefix>(); - if (!RemoveSpecial(future, prefix)) - break; // Prefix not found - } while (Interlocked.CompareExchange(ref all, future, current) != current); - CheckIfRemove(); - return; - } - - Dictionary<ListenerPrefix, HttpListener> prefs; - Dictionary<ListenerPrefix, HttpListener> p2; - do - { - prefs = prefixes; - if (!prefs.ContainsKey(prefix)) - break; - - p2 = new Dictionary<ListenerPrefix, HttpListener>(prefs); - p2.Remove(prefix); - } while (Interlocked.CompareExchange(ref prefixes, p2, prefs) != prefs); - CheckIfRemove(); - } - } -} diff --git a/SocketHttpListener/Net/EndPointManager.cs b/SocketHttpListener/Net/EndPointManager.cs deleted file mode 100644 index 557caa59a..000000000 --- a/SocketHttpListener/Net/EndPointManager.cs +++ /dev/null @@ -1,167 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Net; -using System.Reflection; -using System.Threading.Tasks; -using MediaBrowser.Model.IO; -using MediaBrowser.Model.Logging; -using MediaBrowser.Model.Net; -using SocketHttpListener.Primitives; - -namespace SocketHttpListener.Net -{ - sealed class EndPointManager - { - // Dictionary<IPAddress, Dictionary<int, EndPointListener>> - static Dictionary<string, Dictionary<int, EndPointListener>> ip_to_endpoints = new Dictionary<string, Dictionary<int, EndPointListener>>(); - - private EndPointManager() - { - } - - public static void AddListener(ILogger logger, HttpListener listener) - { - List<string> added = new List<string>(); - try - { - lock (ip_to_endpoints) - { - foreach (string prefix in listener.Prefixes) - { - AddPrefixInternal(logger, prefix, listener); - added.Add(prefix); - } - } - } - catch - { - foreach (string prefix in added) - { - RemovePrefix(logger, prefix, listener); - } - throw; - } - } - - public static void AddPrefix(ILogger logger, string prefix, HttpListener listener) - { - lock (ip_to_endpoints) - { - AddPrefixInternal(logger, prefix, listener); - } - } - - static void AddPrefixInternal(ILogger logger, string p, HttpListener listener) - { - ListenerPrefix lp = new ListenerPrefix(p); - if (lp.Path.IndexOf('%') != -1) - throw new HttpListenerException(400, "Invalid path."); - - if (lp.Path.IndexOf("//", StringComparison.Ordinal) != -1) // TODO: Code? - throw new HttpListenerException(400, "Invalid path."); - - // listens on all the interfaces if host name cannot be parsed by IPAddress. - EndPointListener epl = GetEPListener(logger, lp.Host, lp.Port, listener, lp.Secure).Result; - epl.AddPrefix(lp, listener); - } - - private static IPAddress GetIpAnyAddress(HttpListener listener) - { - return listener.EnableDualMode ? IPAddress.IPv6Any : IPAddress.Any; - } - - static async Task<EndPointListener> GetEPListener(ILogger logger, string host, int port, HttpListener listener, bool secure) - { - var networkManager = listener.NetworkManager; - - IPAddress addr; - if (host == "*" || host == "+") - addr = GetIpAnyAddress(listener); - else if (IPAddress.TryParse(host, out addr) == false) - { - try - { - var all = (await networkManager.GetHostAddressesAsync(host).ConfigureAwait(false)); - - addr = (all.Length == 0 ? null : IPAddress.Parse(all[0].Address)) ?? - GetIpAnyAddress(listener); - } - catch - { - addr = GetIpAnyAddress(listener); - } - } - - Dictionary<int, EndPointListener> p = null; // Dictionary<int, EndPointListener> - if (!ip_to_endpoints.TryGetValue(addr.ToString(), out p)) - { - p = new Dictionary<int, EndPointListener>(); - ip_to_endpoints[addr.ToString()] = p; - } - - EndPointListener epl = null; - if (p.ContainsKey(port)) - { - epl = (EndPointListener)p[port]; - } - else - { - epl = new EndPointListener(listener, addr, port, secure, listener.Certificate, logger, listener.CryptoProvider, listener.SocketFactory, listener.MemoryStreamFactory, listener.TextEncoding, listener.FileSystem, listener.EnvironmentInfo); - p[port] = epl; - } - - return epl; - } - - public static void RemoveEndPoint(EndPointListener epl, IPEndPoint ep) - { - lock (ip_to_endpoints) - { - // Dictionary<int, EndPointListener> p - Dictionary<int, EndPointListener> p; - if (ip_to_endpoints.TryGetValue(ep.Address.ToString(), out p)) - { - p.Remove(ep.Port); - if (p.Count == 0) - { - ip_to_endpoints.Remove(ep.Address.ToString()); - } - } - epl.Close(); - } - } - - public static void RemoveListener(ILogger logger, HttpListener listener) - { - lock (ip_to_endpoints) - { - foreach (string prefix in listener.Prefixes) - { - RemovePrefixInternal(logger, prefix, listener); - } - } - } - - public static void RemovePrefix(ILogger logger, string prefix, HttpListener listener) - { - lock (ip_to_endpoints) - { - RemovePrefixInternal(logger, prefix, listener); - } - } - - static void RemovePrefixInternal(ILogger logger, string prefix, HttpListener listener) - { - ListenerPrefix lp = new ListenerPrefix(prefix); - if (lp.Path.IndexOf('%') != -1) - return; - - if (lp.Path.IndexOf("//", StringComparison.Ordinal) != -1) - return; - - EndPointListener epl = GetEPListener(logger, lp.Host, lp.Port, listener, lp.Secure).Result; - epl.RemovePrefix(lp, listener); - } - } -} diff --git a/SocketHttpListener/Net/HttpConnection.cs b/SocketHttpListener/Net/HttpConnection.cs index 05576ea1e..9b4fb8705 100644 --- a/SocketHttpListener/Net/HttpConnection.cs +++ b/SocketHttpListener/Net/HttpConnection.cs @@ -13,7 +13,9 @@ using MediaBrowser.Model.Net; using MediaBrowser.Model.System; using MediaBrowser.Model.Text; using SocketHttpListener.Primitives; +using System.Security.Authentication; +using System.Threading; namespace SocketHttpListener.Net { sealed class HttpConnection @@ -22,7 +24,7 @@ namespace SocketHttpListener.Net const int BufferSize = 8192; Socket _socket; Stream _stream; - EndPointListener _epl; + HttpEndPointListener _epl; MemoryStream _memoryStream; byte[] _buffer; HttpListenerContext _context; @@ -34,21 +36,21 @@ namespace SocketHttpListener.Net int _reuses; bool _contextBound; bool secure; - int _timeout = 300000; // 90k ms for first request, 15k ms from then on + int _timeout = 90000; // 90k ms for first request, 15k ms from then on + private Timer _timer; IPEndPoint local_ep; HttpListener _lastListener; - int[] client_cert_errors; X509Certificate cert; SslStream ssl_stream; private readonly ILogger _logger; private readonly ICryptoProvider _cryptoProvider; - private readonly IMemoryStreamFactory _memoryStreamFactory; + private readonly IStreamHelper _streamHelper; private readonly ITextEncoding _textEncoding; private readonly IFileSystem _fileSystem; private readonly IEnvironmentInfo _environment; - private HttpConnection(ILogger logger, Socket socket, EndPointListener epl, bool secure, X509Certificate cert, ICryptoProvider cryptoProvider, IMemoryStreamFactory memoryStreamFactory, ITextEncoding textEncoding, IFileSystem fileSystem, IEnvironmentInfo environment) + public HttpConnection(ILogger logger, Socket socket, HttpEndPointListener epl, bool secure, X509Certificate cert, ICryptoProvider cryptoProvider, IStreamHelper streamHelper, ITextEncoding textEncoding, IFileSystem fileSystem, IEnvironmentInfo environment) { _logger = logger; this._socket = socket; @@ -56,47 +58,37 @@ namespace SocketHttpListener.Net this.secure = secure; this.cert = cert; _cryptoProvider = cryptoProvider; - _memoryStreamFactory = memoryStreamFactory; + _streamHelper = streamHelper; _textEncoding = textEncoding; _fileSystem = fileSystem; _environment = environment; - } - private async Task InitStream() - { if (secure == false) { _stream = new SocketStream(_socket, false); } else { - //ssl_stream = _epl.Listener.CreateSslStream(new NetworkStream(_socket, false), false, (t, c, ch, e) => - //{ - // if (c == null) - // return true; - // var c2 = c as X509Certificate2; - // if (c2 == null) - // c2 = new X509Certificate2(c.GetRawCertData()); - // client_cert = c2; - // client_cert_errors = new int[] { (int)e }; - // return true; - //}); - //_stream = ssl_stream.AuthenticatedStream; - - ssl_stream = new SslStream(new SocketStream(_socket, false), false); - await ssl_stream.AuthenticateAsServerAsync(cert).ConfigureAwait(false); - _stream = ssl_stream; - } - Init(); - } + ssl_stream = new SslStream(new SocketStream(_socket, false), false, (t, c, ch, e) => + { + if (c == null) + { + return true; + } - public static async Task<HttpConnection> Create(ILogger logger, Socket sock, EndPointListener epl, bool secure, X509Certificate cert, ICryptoProvider cryptoProvider, IMemoryStreamFactory memoryStreamFactory, ITextEncoding textEncoding, IFileSystem fileSystem, IEnvironmentInfo environment) - { - var connection = new HttpConnection(logger, sock, epl, secure, cert, cryptoProvider, memoryStreamFactory, textEncoding, fileSystem, environment); + //var c2 = c as X509Certificate2; + //if (c2 == null) + //{ + // c2 = new X509Certificate2(c.GetRawCertData()); + //} - await connection.InitStream().ConfigureAwait(false); + //_clientCert = c2; + //_clientCertErrors = new int[] { (int)e }; + return true; + }); - return connection; + _stream = ssl_stream; + } } public Stream Stream @@ -107,12 +99,27 @@ namespace SocketHttpListener.Net } } - internal int[] ClientCertificateErrors + public async Task Init() { - get { return client_cert_errors; } + _timer = new Timer(OnTimeout, null, Timeout.Infinite, Timeout.Infinite); + + if (ssl_stream != null) + { + var enableAsync = true; + if (enableAsync) + { + await ssl_stream.AuthenticateAsServerAsync(cert, false, (SslProtocols)ServicePointManager.SecurityProtocol, false).ConfigureAwait(false); + } + else + { + ssl_stream.AuthenticateAsServer(cert, false, (SslProtocols)ServicePointManager.SecurityProtocol, false); + } + } + + InitInternal(); } - void Init() + private void InitInternal() { _contextBound = false; _requestStream = null; @@ -123,7 +130,7 @@ namespace SocketHttpListener.Net _position = 0; _inputState = InputState.RequestLine; _lineState = LineState.None; - _context = new HttpListenerContext(this, _logger, _cryptoProvider, _memoryStreamFactory, _textEncoding, _fileSystem); + _context = new HttpListenerContext(this, _textEncoding); } public bool IsClosed @@ -164,6 +171,13 @@ namespace SocketHttpListener.Net set { _prefix = value; } } + private void OnTimeout(object unused) + { + //_logger.Info("HttpConnection timer fired"); + CloseSocket(); + Unbind(); + } + public void BeginReadRequest() { if (_buffer == null) @@ -211,7 +225,7 @@ namespace SocketHttpListener.Net { var supportsDirectSocketAccess = !_context.Response.SendChunked && !isExpect100Continue && !secure; - _responseStream = new HttpResponseStream(_stream, _context.Response, false, _memoryStreamFactory, _socket, supportsDirectSocketAccess, _environment, _fileSystem, _logger); + _responseStream = new HttpResponseStream(_stream, _context.Response, false, _streamHelper, _socket, supportsDirectSocketAccess, _environment, _fileSystem, _logger); } return _responseStream; } @@ -503,14 +517,14 @@ namespace SocketHttpListener.Net // Don't close. Keep working. _reuses++; Unbind(); - Init(); + InitInternal(); BeginReadRequest(); return; } _reuses++; Unbind(); - Init(); + InitInternal(); BeginReadRequest(); return; } diff --git a/SocketHttpListener/Net/HttpEndPointListener.cs b/SocketHttpListener/Net/HttpEndPointListener.cs new file mode 100644 index 000000000..254e76140 --- /dev/null +++ b/SocketHttpListener/Net/HttpEndPointListener.cs @@ -0,0 +1,539 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using MediaBrowser.Model.Cryptography; +using MediaBrowser.Model.IO; +using MediaBrowser.Model.Logging; +using MediaBrowser.Model.Net; +using MediaBrowser.Model.System; +using MediaBrowser.Model.Text; +using SocketHttpListener.Primitives; +using ProtocolType = MediaBrowser.Model.Net.ProtocolType; +using SocketType = MediaBrowser.Model.Net.SocketType; +using System.Threading.Tasks; + +namespace SocketHttpListener.Net +{ + internal sealed class HttpEndPointListener + { + private HttpListener _listener; + private IPEndPoint _endpoint; + private Socket _socket; + private Dictionary<ListenerPrefix, HttpListener> _prefixes; + private List<ListenerPrefix> _unhandledPrefixes; // host = '*' + private List<ListenerPrefix> _allPrefixes; // host = '+' + private X509Certificate _cert; + private bool _secure; + private Dictionary<HttpConnection, HttpConnection> _unregisteredConnections; + + private readonly ILogger _logger; + private bool _closed; + private bool _enableDualMode; + private readonly ICryptoProvider _cryptoProvider; + private readonly ISocketFactory _socketFactory; + private readonly ITextEncoding _textEncoding; + private readonly IStreamHelper _streamHelper; + private readonly IFileSystem _fileSystem; + private readonly IEnvironmentInfo _environment; + + public HttpEndPointListener(HttpListener listener, IPAddress addr, int port, bool secure, X509Certificate cert, ILogger logger, ICryptoProvider cryptoProvider, ISocketFactory socketFactory, IStreamHelper streamHelper, ITextEncoding textEncoding, IFileSystem fileSystem, IEnvironmentInfo environment) + { + this._listener = listener; + _logger = logger; + _cryptoProvider = cryptoProvider; + _socketFactory = socketFactory; + _streamHelper = streamHelper; + _textEncoding = textEncoding; + _fileSystem = fileSystem; + _environment = environment; + + this._secure = secure; + this._cert = cert; + + _enableDualMode = addr.Equals(IPAddress.IPv6Any); + _endpoint = new IPEndPoint(addr, port); + + _prefixes = new Dictionary<ListenerPrefix, HttpListener>(); + _unregisteredConnections = new Dictionary<HttpConnection, HttpConnection>(); + + CreateSocket(); + } + + internal HttpListener Listener + { + get + { + return _listener; + } + } + + private void CreateSocket() + { + try + { + _socket = CreateSocket(_endpoint.Address.AddressFamily, _enableDualMode); + } + catch (SocketCreateException ex) + { + if (_enableDualMode && _endpoint.Address.Equals(IPAddress.IPv6Any) && + (string.Equals(ex.ErrorCode, "AddressFamilyNotSupported", StringComparison.OrdinalIgnoreCase) || + // mono 4.8.1 and lower on bsd is throwing this + string.Equals(ex.ErrorCode, "ProtocolNotSupported", StringComparison.OrdinalIgnoreCase) || + // mono 5.2 on bsd is throwing this + string.Equals(ex.ErrorCode, "OperationNotSupported", StringComparison.OrdinalIgnoreCase))) + { + _endpoint = new IPEndPoint(IPAddress.Any, _endpoint.Port); + _enableDualMode = false; + _socket = CreateSocket(_endpoint.Address.AddressFamily, _enableDualMode); + } + else + { + throw; + } + } + + try + { + _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); + } + catch (SocketException) + { + // This is not supported on all operating systems (qnap) + } + + _socket.Bind(_endpoint); + + // This is the number TcpListener uses. + _socket.Listen(2147483647); + + Accept(); + + _closed = false; + } + + private void Accept() + { + var acceptEventArg = new SocketAsyncEventArgs(); + acceptEventArg.UserToken = this; + acceptEventArg.Completed += new EventHandler<SocketAsyncEventArgs>(OnAccept); + + Accept(acceptEventArg); + } + + private static void TryCloseAndDispose(Socket socket) + { + try + { + using (socket) + { + socket.Close(); + } + } + catch + { + + } + } + + private static void TryClose(Socket socket) + { + try + { + socket.Close(); + } + catch + { + + } + } + + private void Accept(SocketAsyncEventArgs acceptEventArg) + { + // acceptSocket must be cleared since the context object is being reused + acceptEventArg.AcceptSocket = null; + + try + { + bool willRaiseEvent = _socket.AcceptAsync(acceptEventArg); + + if (!willRaiseEvent) + { + ProcessAccept(acceptEventArg); + } + } + catch (ObjectDisposedException) + { + } + catch (Exception ex) + { + HttpEndPointListener epl = (HttpEndPointListener)acceptEventArg.UserToken; + + epl._logger.ErrorException("Error in socket.AcceptAsync", ex); + } + } + + // This method is the callback method associated with Socket.AcceptAsync + // operations and is invoked when an accept operation is complete + // + private static void OnAccept(object sender, SocketAsyncEventArgs e) + { + ProcessAccept(e); + } + + private static async void ProcessAccept(SocketAsyncEventArgs args) + { + HttpEndPointListener epl = (HttpEndPointListener)args.UserToken; + + if (epl._closed) + { + return; + } + + // http://msdn.microsoft.com/en-us/library/system.net.sockets.acceptSocket.acceptasync%28v=vs.110%29.aspx + // Under certain conditions ConnectionReset can occur + // Need to attept to re-accept + var socketError = args.SocketError; + var accepted = args.AcceptSocket; + + epl.Accept(args); + + if (socketError == SocketError.ConnectionReset) + { + epl._logger.Error("SocketError.ConnectionReset reported. Attempting to re-accept."); + return; + } + + if(accepted == null) + { + return; + } + + if (epl._secure && epl._cert == null) + { + TryClose(accepted); + return; + } + + try + { + var remoteEndPointString = accepted.RemoteEndPoint == null ? string.Empty : accepted.RemoteEndPoint.ToString(); + var localEndPointString = accepted.LocalEndPoint == null ? string.Empty : accepted.LocalEndPoint.ToString(); + //_logger.Info("HttpEndPointListener Accepting connection from {0} to {1} secure connection requested: {2}", remoteEndPointString, localEndPointString, _secure); + + HttpConnection conn = new HttpConnection(epl._logger, accepted, epl, epl._secure, epl._cert, epl._cryptoProvider, epl._streamHelper, epl._textEncoding, epl._fileSystem, epl._environment); + + await conn.Init().ConfigureAwait(false); + + //_logger.Debug("Adding unregistered connection to {0}. Id: {1}", accepted.RemoteEndPoint, connectionId); + lock (epl._unregisteredConnections) + { + epl._unregisteredConnections[conn] = conn; + } + conn.BeginReadRequest(); + } + catch (Exception ex) + { + epl._logger.ErrorException("Error in ProcessAccept", ex); + + TryClose(accepted); + epl.Accept(); + return; + } + } + + private Socket CreateSocket(AddressFamily addressFamily, bool dualMode) + { + try + { + var socket = new Socket(addressFamily, System.Net.Sockets.SocketType.Stream, System.Net.Sockets.ProtocolType.Tcp); + + if (dualMode) + { + socket.DualMode = true; + } + + return socket; + } + catch (SocketException ex) + { + throw new SocketCreateException(ex.SocketErrorCode.ToString(), ex); + } + catch (ArgumentException ex) + { + if (dualMode) + { + // Mono for BSD incorrectly throws ArgumentException instead of SocketException + throw new SocketCreateException("AddressFamilyNotSupported", ex); + } + else + { + throw; + } + } + } + + internal void RemoveConnection(HttpConnection conn) + { + lock (_unregisteredConnections) + { + _unregisteredConnections.Remove(conn); + } + } + + public bool BindContext(HttpListenerContext context) + { + HttpListenerRequest req = context.Request; + ListenerPrefix prefix; + HttpListener listener = SearchListener(req.Url, out prefix); + if (listener == null) + return false; + + context.Connection.Prefix = prefix; + return true; + } + + public void UnbindContext(HttpListenerContext context) + { + if (context == null || context.Request == null) + return; + + _listener.UnregisterContext(context); + } + + private HttpListener SearchListener(Uri uri, out ListenerPrefix prefix) + { + prefix = null; + if (uri == null) + return null; + + string host = uri.Host; + int port = uri.Port; + string path = WebUtility.UrlDecode(uri.AbsolutePath); + string pathSlash = path[path.Length - 1] == '/' ? path : path + "/"; + + HttpListener bestMatch = null; + int bestLength = -1; + + if (host != null && host != "") + { + Dictionary<ListenerPrefix, HttpListener> localPrefixes = _prefixes; + foreach (ListenerPrefix p in localPrefixes.Keys) + { + string ppath = p.Path; + if (ppath.Length < bestLength) + continue; + + if (p.Host != host || p.Port != port) + continue; + + if (path.StartsWith(ppath) || pathSlash.StartsWith(ppath)) + { + bestLength = ppath.Length; + bestMatch = localPrefixes[p]; + prefix = p; + } + } + if (bestLength != -1) + return bestMatch; + } + + List<ListenerPrefix> list = _unhandledPrefixes; + bestMatch = MatchFromList(host, path, list, out prefix); + + if (path != pathSlash && bestMatch == null) + bestMatch = MatchFromList(host, pathSlash, list, out prefix); + + if (bestMatch != null) + return bestMatch; + + list = _allPrefixes; + bestMatch = MatchFromList(host, path, list, out prefix); + + if (path != pathSlash && bestMatch == null) + bestMatch = MatchFromList(host, pathSlash, list, out prefix); + + if (bestMatch != null) + return bestMatch; + + return null; + } + + private HttpListener MatchFromList(string host, string path, List<ListenerPrefix> list, out ListenerPrefix prefix) + { + prefix = null; + if (list == null) + return null; + + HttpListener bestMatch = null; + int bestLength = -1; + + foreach (ListenerPrefix p in list) + { + string ppath = p.Path; + if (ppath.Length < bestLength) + continue; + + if (path.StartsWith(ppath)) + { + bestLength = ppath.Length; + bestMatch = p._listener; + prefix = p; + } + } + + return bestMatch; + } + + private void AddSpecial(List<ListenerPrefix> list, ListenerPrefix prefix) + { + if (list == null) + return; + + foreach (ListenerPrefix p in list) + { + if (p.Path == prefix.Path) + throw new Exception("net_listener_already"); + } + list.Add(prefix); + } + + private bool RemoveSpecial(List<ListenerPrefix> list, ListenerPrefix prefix) + { + if (list == null) + return false; + + int c = list.Count; + for (int i = 0; i < c; i++) + { + ListenerPrefix p = list[i]; + if (p.Path == prefix.Path) + { + list.RemoveAt(i); + return true; + } + } + return false; + } + + private void CheckIfRemove() + { + if (_prefixes.Count > 0) + return; + + List<ListenerPrefix> list = _unhandledPrefixes; + if (list != null && list.Count > 0) + return; + + list = _allPrefixes; + if (list != null && list.Count > 0) + return; + + HttpEndPointManager.RemoveEndPoint(this, _endpoint); + } + + public void Close() + { + _closed = true; + _socket.Close(); + lock (_unregisteredConnections) + { + // Clone the list because RemoveConnection can be called from Close + var connections = new List<HttpConnection>(_unregisteredConnections.Keys); + + foreach (HttpConnection c in connections) + c.Close(true); + _unregisteredConnections.Clear(); + } + } + + public void AddPrefix(ListenerPrefix prefix, HttpListener listener) + { + List<ListenerPrefix> current; + List<ListenerPrefix> future; + if (prefix.Host == "*") + { + do + { + current = _unhandledPrefixes; + future = current != null ? new List<ListenerPrefix>(current) : new List<ListenerPrefix>(); + prefix._listener = listener; + AddSpecial(future, prefix); + } while (Interlocked.CompareExchange(ref _unhandledPrefixes, future, current) != current); + return; + } + + if (prefix.Host == "+") + { + do + { + current = _allPrefixes; + future = current != null ? new List<ListenerPrefix>(current) : new List<ListenerPrefix>(); + prefix._listener = listener; + AddSpecial(future, prefix); + } while (Interlocked.CompareExchange(ref _allPrefixes, future, current) != current); + return; + } + + Dictionary<ListenerPrefix, HttpListener> prefs, p2; + do + { + prefs = _prefixes; + if (prefs.ContainsKey(prefix)) + { + throw new Exception("net_listener_already"); + } + p2 = new Dictionary<ListenerPrefix, HttpListener>(prefs); + p2[prefix] = listener; + } while (Interlocked.CompareExchange(ref _prefixes, p2, prefs) != prefs); + } + + public void RemovePrefix(ListenerPrefix prefix, HttpListener listener) + { + List<ListenerPrefix> current; + List<ListenerPrefix> future; + if (prefix.Host == "*") + { + do + { + current = _unhandledPrefixes; + future = current != null ? new List<ListenerPrefix>(current) : new List<ListenerPrefix>(); + if (!RemoveSpecial(future, prefix)) + break; // Prefix not found + } while (Interlocked.CompareExchange(ref _unhandledPrefixes, future, current) != current); + + CheckIfRemove(); + return; + } + + if (prefix.Host == "+") + { + do + { + current = _allPrefixes; + future = current != null ? new List<ListenerPrefix>(current) : new List<ListenerPrefix>(); + if (!RemoveSpecial(future, prefix)) + break; // Prefix not found + } while (Interlocked.CompareExchange(ref _allPrefixes, future, current) != current); + CheckIfRemove(); + return; + } + + Dictionary<ListenerPrefix, HttpListener> prefs, p2; + do + { + prefs = _prefixes; + if (!prefs.ContainsKey(prefix)) + break; + + p2 = new Dictionary<ListenerPrefix, HttpListener>(prefs); + p2.Remove(prefix); + } while (Interlocked.CompareExchange(ref _prefixes, p2, prefs) != prefs); + CheckIfRemove(); + } + } +} diff --git a/SocketHttpListener/Net/HttpEndPointManager.cs b/SocketHttpListener/Net/HttpEndPointManager.cs new file mode 100644 index 000000000..45af92c01 --- /dev/null +++ b/SocketHttpListener/Net/HttpEndPointManager.cs @@ -0,0 +1,198 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Net; +using System.Net.Sockets; +using System.Reflection; +using System.Threading.Tasks; +using MediaBrowser.Model.IO; +using MediaBrowser.Model.Logging; +using MediaBrowser.Model.Net; +using SocketHttpListener.Primitives; + +namespace SocketHttpListener.Net +{ + internal sealed class HttpEndPointManager + { + private static Dictionary<IPAddress, Dictionary<int, HttpEndPointListener>> s_ipEndPoints = new Dictionary<IPAddress, Dictionary<int, HttpEndPointListener>>(); + + private HttpEndPointManager() + { + } + + public static void AddListener(ILogger logger, HttpListener listener) + { + List<string> added = new List<string>(); + try + { + lock ((s_ipEndPoints as ICollection).SyncRoot) + { + foreach (string prefix in listener.Prefixes) + { + AddPrefixInternal(logger, prefix, listener); + added.Add(prefix); + } + } + } + catch + { + foreach (string prefix in added) + { + RemovePrefix(logger, prefix, listener); + } + throw; + } + } + + public static void AddPrefix(ILogger logger, string prefix, HttpListener listener) + { + lock ((s_ipEndPoints as ICollection).SyncRoot) + { + AddPrefixInternal(logger, prefix, listener); + } + } + + private static void AddPrefixInternal(ILogger logger, string p, HttpListener listener) + { + int start = p.IndexOf(':') + 3; + int colon = p.IndexOf(':', start); + if (colon != -1) + { + // root can't be -1 here, since we've already checked for ending '/' in ListenerPrefix. + int root = p.IndexOf('/', colon, p.Length - colon); + string portString = p.Substring(colon + 1, root - colon - 1); + + int port; + if (!int.TryParse(portString, out port) || port <= 0 || port >= 65536) + { + throw new HttpListenerException((int)HttpStatusCode.BadRequest, "net_invalid_port"); + } + } + + ListenerPrefix lp = new ListenerPrefix(p); + if (lp.Host != "*" && lp.Host != "+" && Uri.CheckHostName(lp.Host) == UriHostNameType.Unknown) + throw new HttpListenerException((int)HttpStatusCode.BadRequest, "net_listener_host"); + + if (lp.Path.IndexOf('%') != -1) + throw new HttpListenerException((int)HttpStatusCode.BadRequest, "net_invalid_path"); + + if (lp.Path.IndexOf("//", StringComparison.Ordinal) != -1) + throw new HttpListenerException((int)HttpStatusCode.BadRequest, "net_invalid_path"); + + // listens on all the interfaces if host name cannot be parsed by IPAddress. + HttpEndPointListener epl = GetEPListener(logger, lp.Host, lp.Port, listener, lp.Secure); + epl.AddPrefix(lp, listener); + } + + private static IPAddress GetIpAnyAddress(HttpListener listener) + { + return listener.EnableDualMode ? IPAddress.IPv6Any : IPAddress.Any; + } + + private static HttpEndPointListener GetEPListener(ILogger logger, string host, int port, HttpListener listener, bool secure) + { + IPAddress addr; + if (host == "*" || host == "+") + { + addr = GetIpAnyAddress(listener); + } + else + { + const int NotSupportedErrorCode = 50; + try + { + addr = Dns.GetHostAddresses(host)[0]; + } + catch + { + // Throw same error code as windows, request is not supported. + throw new HttpListenerException(NotSupportedErrorCode, "net_listener_not_supported"); + } + + if (IPAddress.Any.Equals(addr)) + { + // Don't support listening to 0.0.0.0, match windows behavior. + throw new HttpListenerException(NotSupportedErrorCode, "net_listener_not_supported"); + } + } + + Dictionary<int, HttpEndPointListener> p = null; + if (s_ipEndPoints.ContainsKey(addr)) + { + p = s_ipEndPoints[addr]; + } + else + { + p = new Dictionary<int, HttpEndPointListener>(); + s_ipEndPoints[addr] = p; + } + + HttpEndPointListener epl = null; + if (p.ContainsKey(port)) + { + epl = p[port]; + } + else + { + try + { + epl = new HttpEndPointListener(listener, addr, port, secure, listener.Certificate, logger, listener.CryptoProvider, listener.SocketFactory, listener.StreamHelper, listener.TextEncoding, listener.FileSystem, listener.EnvironmentInfo); + } + catch (SocketException ex) + { + throw new HttpListenerException(ex.ErrorCode, ex.Message); + } + p[port] = epl; + } + + return epl; + } + + public static void RemoveEndPoint(HttpEndPointListener epl, IPEndPoint ep) + { + lock ((s_ipEndPoints as ICollection).SyncRoot) + { + Dictionary<int, HttpEndPointListener> p = null; + p = s_ipEndPoints[ep.Address]; + p.Remove(ep.Port); + if (p.Count == 0) + { + s_ipEndPoints.Remove(ep.Address); + } + epl.Close(); + } + } + + public static void RemoveListener(ILogger logger, HttpListener listener) + { + lock ((s_ipEndPoints as ICollection).SyncRoot) + { + foreach (string prefix in listener.Prefixes) + { + RemovePrefixInternal(logger, prefix, listener); + } + } + } + + public static void RemovePrefix(ILogger logger, string prefix, HttpListener listener) + { + lock ((s_ipEndPoints as ICollection).SyncRoot) + { + RemovePrefixInternal(logger, prefix, listener); + } + } + + private static void RemovePrefixInternal(ILogger logger, string prefix, HttpListener listener) + { + ListenerPrefix lp = new ListenerPrefix(prefix); + if (lp.Path.IndexOf('%') != -1) + return; + + if (lp.Path.IndexOf("//", StringComparison.Ordinal) != -1) + return; + + HttpEndPointListener epl = GetEPListener(logger, lp.Host, lp.Port, listener, lp.Secure); + epl.RemovePrefix(lp, listener); + } + } +} diff --git a/SocketHttpListener/Net/HttpKnownHeaderNames.cs b/SocketHttpListener/Net/HttpKnownHeaderNames.cs new file mode 100644 index 000000000..ea4695850 --- /dev/null +++ b/SocketHttpListener/Net/HttpKnownHeaderNames.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace SocketHttpListener.Net +{ + internal static partial class HttpKnownHeaderNames + { + // When adding a new constant, add it to HttpKnownHeaderNames.TryGetHeaderName.cs as well. + + public const string Accept = "Accept"; + public const string AcceptCharset = "Accept-Charset"; + public const string AcceptEncoding = "Accept-Encoding"; + public const string AcceptLanguage = "Accept-Language"; + public const string AcceptPatch = "Accept-Patch"; + public const string AcceptRanges = "Accept-Ranges"; + public const string AccessControlAllowCredentials = "Access-Control-Allow-Credentials"; + public const string AccessControlAllowHeaders = "Access-Control-Allow-Headers"; + public const string AccessControlAllowMethods = "Access-Control-Allow-Methods"; + public const string AccessControlAllowOrigin = "Access-Control-Allow-Origin"; + public const string AccessControlExposeHeaders = "Access-Control-Expose-Headers"; + public const string AccessControlMaxAge = "Access-Control-Max-Age"; + public const string Age = "Age"; + public const string Allow = "Allow"; + public const string AltSvc = "Alt-Svc"; + public const string Authorization = "Authorization"; + public const string CacheControl = "Cache-Control"; + public const string Connection = "Connection"; + public const string ContentDisposition = "Content-Disposition"; + public const string ContentEncoding = "Content-Encoding"; + public const string ContentLanguage = "Content-Language"; + public const string ContentLength = "Content-Length"; + public const string ContentLocation = "Content-Location"; + public const string ContentMD5 = "Content-MD5"; + public const string ContentRange = "Content-Range"; + public const string ContentSecurityPolicy = "Content-Security-Policy"; + public const string ContentType = "Content-Type"; + public const string Cookie = "Cookie"; + public const string Cookie2 = "Cookie2"; + public const string Date = "Date"; + public const string ETag = "ETag"; + public const string Expect = "Expect"; + public const string Expires = "Expires"; + public const string From = "From"; + public const string Host = "Host"; + public const string IfMatch = "If-Match"; + public const string IfModifiedSince = "If-Modified-Since"; + public const string IfNoneMatch = "If-None-Match"; + public const string IfRange = "If-Range"; + public const string IfUnmodifiedSince = "If-Unmodified-Since"; + public const string KeepAlive = "Keep-Alive"; + public const string LastModified = "Last-Modified"; + public const string Link = "Link"; + public const string Location = "Location"; + public const string MaxForwards = "Max-Forwards"; + public const string Origin = "Origin"; + public const string P3P = "P3P"; + public const string Pragma = "Pragma"; + public const string ProxyAuthenticate = "Proxy-Authenticate"; + public const string ProxyAuthorization = "Proxy-Authorization"; + public const string ProxyConnection = "Proxy-Connection"; + public const string PublicKeyPins = "Public-Key-Pins"; + public const string Range = "Range"; + public const string Referer = "Referer"; // NB: The spelling-mistake "Referer" for "Referrer" must be matched. + public const string RetryAfter = "Retry-After"; + public const string SecWebSocketAccept = "Sec-WebSocket-Accept"; + public const string SecWebSocketExtensions = "Sec-WebSocket-Extensions"; + public const string SecWebSocketKey = "Sec-WebSocket-Key"; + public const string SecWebSocketProtocol = "Sec-WebSocket-Protocol"; + public const string SecWebSocketVersion = "Sec-WebSocket-Version"; + public const string Server = "Server"; + public const string SetCookie = "Set-Cookie"; + public const string SetCookie2 = "Set-Cookie2"; + public const string StrictTransportSecurity = "Strict-Transport-Security"; + public const string TE = "TE"; + public const string TSV = "TSV"; + public const string Trailer = "Trailer"; + public const string TransferEncoding = "Transfer-Encoding"; + public const string Upgrade = "Upgrade"; + public const string UpgradeInsecureRequests = "Upgrade-Insecure-Requests"; + public const string UserAgent = "User-Agent"; + public const string Vary = "Vary"; + public const string Via = "Via"; + public const string WWWAuthenticate = "WWW-Authenticate"; + public const string Warning = "Warning"; + public const string XAspNetVersion = "X-AspNet-Version"; + public const string XContentDuration = "X-Content-Duration"; + public const string XContentTypeOptions = "X-Content-Type-Options"; + public const string XFrameOptions = "X-Frame-Options"; + public const string XMSEdgeRef = "X-MSEdge-Ref"; + public const string XPoweredBy = "X-Powered-By"; + public const string XRequestID = "X-Request-ID"; + public const string XUACompatible = "X-UA-Compatible"; + } +} diff --git a/SocketHttpListener/Net/HttpListener.cs b/SocketHttpListener/Net/HttpListener.cs index 32c5e90e0..759be64c9 100644 --- a/SocketHttpListener/Net/HttpListener.cs +++ b/SocketHttpListener/Net/HttpListener.cs @@ -21,7 +21,7 @@ namespace SocketHttpListener.Net internal ISocketFactory SocketFactory { get; private set; } internal IFileSystem FileSystem { get; private set; } internal ITextEncoding TextEncoding { get; private set; } - internal IMemoryStreamFactory MemoryStreamFactory { get; private set; } + internal IStreamHelper StreamHelper { get; private set; } internal INetworkManager NetworkManager { get; private set; } internal IEnvironmentInfo EnvironmentInfo { get; private set; } @@ -42,14 +42,14 @@ namespace SocketHttpListener.Net public Action<HttpListenerContext> OnContext { get; set; } - public HttpListener(ILogger logger, ICryptoProvider cryptoProvider, ISocketFactory socketFactory, INetworkManager networkManager, ITextEncoding textEncoding, IMemoryStreamFactory memoryStreamFactory, IFileSystem fileSystem, IEnvironmentInfo environmentInfo) + public HttpListener(ILogger logger, ICryptoProvider cryptoProvider, ISocketFactory socketFactory, INetworkManager networkManager, ITextEncoding textEncoding, IStreamHelper streamHelper, IFileSystem fileSystem, IEnvironmentInfo environmentInfo) { _logger = logger; CryptoProvider = cryptoProvider; SocketFactory = socketFactory; NetworkManager = networkManager; TextEncoding = textEncoding; - MemoryStreamFactory = memoryStreamFactory; + StreamHelper = streamHelper; FileSystem = fileSystem; EnvironmentInfo = environmentInfo; prefixes = new HttpListenerPrefixCollection(logger, this); @@ -58,13 +58,13 @@ namespace SocketHttpListener.Net auth_schemes = AuthenticationSchemes.Anonymous; } - public HttpListener(X509Certificate certificate, ICryptoProvider cryptoProvider, ISocketFactory socketFactory, INetworkManager networkManager, ITextEncoding textEncoding, IMemoryStreamFactory memoryStreamFactory, IFileSystem fileSystem, IEnvironmentInfo environmentInfo) - :this(new NullLogger(), certificate, cryptoProvider, socketFactory, networkManager, textEncoding, memoryStreamFactory, fileSystem, environmentInfo) + public HttpListener(X509Certificate certificate, ICryptoProvider cryptoProvider, ISocketFactory socketFactory, INetworkManager networkManager, ITextEncoding textEncoding, IStreamHelper streamHelper, IFileSystem fileSystem, IEnvironmentInfo environmentInfo) + :this(new NullLogger(), certificate, cryptoProvider, socketFactory, networkManager, textEncoding, streamHelper, fileSystem, environmentInfo) { } - public HttpListener(ILogger logger, X509Certificate certificate, ICryptoProvider cryptoProvider, ISocketFactory socketFactory, INetworkManager networkManager, ITextEncoding textEncoding, IMemoryStreamFactory memoryStreamFactory, IFileSystem fileSystem, IEnvironmentInfo environmentInfo) - : this(logger, cryptoProvider, socketFactory, networkManager, textEncoding, memoryStreamFactory, fileSystem, environmentInfo) + public HttpListener(ILogger logger, X509Certificate certificate, ICryptoProvider cryptoProvider, ISocketFactory socketFactory, INetworkManager networkManager, ITextEncoding textEncoding, IStreamHelper streamHelper, IFileSystem fileSystem, IEnvironmentInfo environmentInfo) + : this(logger, cryptoProvider, socketFactory, networkManager, textEncoding, streamHelper, fileSystem, environmentInfo) { _certificate = certificate; } @@ -185,7 +185,7 @@ namespace SocketHttpListener.Net void Close(bool force) { CheckDisposed(); - EndPointManager.RemoveListener(_logger, this); + HttpEndPointManager.RemoveListener(_logger, this); Cleanup(force); } @@ -230,7 +230,7 @@ namespace SocketHttpListener.Net if (listening) return; - EndPointManager.AddListener(_logger, this); + HttpEndPointManager.AddListener(_logger, this); listening = true; } @@ -248,7 +248,6 @@ namespace SocketHttpListener.Net Close(true); //TODO: Should we force here or not? disposed = true; - GC.SuppressFinalize(this); } internal void CheckDisposed() diff --git a/SocketHttpListener/Net/HttpListenerContext.Managed.cs b/SocketHttpListener/Net/HttpListenerContext.Managed.cs new file mode 100644 index 000000000..db795f742 --- /dev/null +++ b/SocketHttpListener/Net/HttpListenerContext.Managed.cs @@ -0,0 +1,100 @@ +using System.ComponentModel; +using System.Security.Principal; +using System.Text; +using System.Threading.Tasks; +using System; +using MediaBrowser.Model.Text; +using SocketHttpListener.Net.WebSockets; + +namespace SocketHttpListener.Net +{ + public sealed unsafe partial class HttpListenerContext + { + private HttpConnection _connection; + + internal HttpListenerContext(HttpConnection connection, ITextEncoding textEncoding) + { + _connection = connection; + _response = new HttpListenerResponse(this, textEncoding); + Request = new HttpListenerRequest(this); + ErrorStatus = 400; + } + + internal int ErrorStatus { get; set; } + + internal string ErrorMessage { get; set; } + + internal bool HaveError => ErrorMessage != null; + + internal HttpConnection Connection => _connection; + + internal void ParseAuthentication(System.Net.AuthenticationSchemes expectedSchemes) + { + if (expectedSchemes == System.Net.AuthenticationSchemes.Anonymous) + return; + + string header = Request.Headers["Authorization"]; + if (string.IsNullOrEmpty(header)) + return; + + if (IsBasicHeader(header)) + { + _user = ParseBasicAuthentication(header.Substring(AuthenticationTypes.Basic.Length + 1)); + } + } + + internal IPrincipal ParseBasicAuthentication(string authData) => + TryParseBasicAuth(authData, out HttpStatusCode errorCode, out string username, out string password) ? + new GenericPrincipal(new HttpListenerBasicIdentity(username, password), Array.Empty<string>()) : + null; + + internal static bool IsBasicHeader(string header) => + header.Length >= 6 && + header[5] == ' ' && + string.Compare(header, 0, AuthenticationTypes.Basic, 0, 5, StringComparison.OrdinalIgnoreCase) == 0; + + internal static bool TryParseBasicAuth(string headerValue, out HttpStatusCode errorCode, out string username, out string password) + { + errorCode = HttpStatusCode.OK; + username = password = null; + try + { + if (string.IsNullOrWhiteSpace(headerValue)) + { + return false; + } + + string authString = Encoding.UTF8.GetString(Convert.FromBase64String(headerValue)); + int colonPos = authString.IndexOf(':'); + if (colonPos < 0) + { + // username must be at least 1 char + errorCode = HttpStatusCode.BadRequest; + return false; + } + + username = authString.Substring(0, colonPos); + password = authString.Substring(colonPos + 1); + return true; + } + catch + { + errorCode = HttpStatusCode.InternalServerError; + return false; + } + } + + public Task<HttpListenerWebSocketContext> AcceptWebSocketAsync(string subProtocol, int receiveBufferSize, TimeSpan keepAliveInterval) + { + return HttpWebSocket.AcceptWebSocketAsyncCore(this, subProtocol, receiveBufferSize, keepAliveInterval); + } + + [EditorBrowsable(EditorBrowsableState.Never)] + public Task<HttpListenerWebSocketContext> AcceptWebSocketAsync(string subProtocol, int receiveBufferSize, TimeSpan keepAliveInterval, ArraySegment<byte> internalBuffer) + { + WebSocketValidate.ValidateArraySegment(internalBuffer, nameof(internalBuffer)); + HttpWebSocket.ValidateOptions(subProtocol, receiveBufferSize, HttpWebSocket.MinSendBufferSize, keepAliveInterval); + return HttpWebSocket.AcceptWebSocketAsyncCore(this, subProtocol, receiveBufferSize, keepAliveInterval, internalBuffer); + } + } +} diff --git a/SocketHttpListener/Net/HttpListenerContext.cs b/SocketHttpListener/Net/HttpListenerContext.cs index 1bf39589d..f4679568a 100644 --- a/SocketHttpListener/Net/HttpListenerContext.cs +++ b/SocketHttpListener/Net/HttpListenerContext.cs @@ -7,145 +7,39 @@ using MediaBrowser.Model.Logging; using MediaBrowser.Model.Text; using SocketHttpListener.Net.WebSockets; using SocketHttpListener.Primitives; +using System.Threading.Tasks; namespace SocketHttpListener.Net { - public sealed class HttpListenerContext + public sealed unsafe partial class HttpListenerContext { - HttpListenerRequest request; - HttpListenerResponse response; - IPrincipal user; - HttpConnection cnc; - string error; - int err_status = 400; - private readonly ICryptoProvider _cryptoProvider; - private readonly IMemoryStreamFactory _memoryStreamFactory; - private readonly ITextEncoding _textEncoding; + internal HttpListener _listener; + private HttpListenerResponse _response; + private IPrincipal _user; - internal HttpListenerContext(HttpConnection cnc, ILogger logger, ICryptoProvider cryptoProvider, IMemoryStreamFactory memoryStreamFactory, ITextEncoding textEncoding, IFileSystem fileSystem) - { - this.cnc = cnc; - _cryptoProvider = cryptoProvider; - _memoryStreamFactory = memoryStreamFactory; - _textEncoding = textEncoding; - request = new HttpListenerRequest(this, _textEncoding); - response = new HttpListenerResponse(this, _textEncoding); - } - - internal int ErrorStatus - { - get { return err_status; } - set { err_status = value; } - } - - internal string ErrorMessage - { - get { return error; } - set { error = value; } - } - - internal bool HaveError - { - get { return (error != null); } - } + public HttpListenerRequest Request { get; } - internal HttpConnection Connection - { - get { return cnc; } - } + public IPrincipal User => _user; - public HttpListenerRequest Request - { - get { return request; } - } + // This can be used to cache the results of HttpListener.AuthenticationSchemeSelectorDelegate. + internal AuthenticationSchemes AuthenticationSchemes { get; set; } public HttpListenerResponse Response { - get { return response; } - } - - public IPrincipal User - { - get { return user; } - } - - internal void ParseAuthentication(AuthenticationSchemes expectedSchemes) - { - if (expectedSchemes == AuthenticationSchemes.Anonymous) - return; - - // TODO: Handle NTLM/Digest modes - string header = request.Headers["Authorization"]; - if (header == null || header.Length < 2) - return; - - string[] authenticationData = header.Split(new char[] { ' ' }, 2); - if (string.Equals(authenticationData[0], "basic", StringComparison.OrdinalIgnoreCase)) + get { - user = ParseBasicAuthentication(authenticationData[1]); + return _response; } - // TODO: throw if malformed -> 400 bad request } - internal IPrincipal ParseBasicAuthentication(string authData) + public Task<HttpListenerWebSocketContext> AcceptWebSocketAsync(string subProtocol) { - try - { - // Basic AUTH Data is a formatted Base64 String - //string domain = null; - string user = null; - string password = null; - int pos = -1; - var authDataBytes = Convert.FromBase64String(authData); - string authString = _textEncoding.GetDefaultEncoding().GetString(authDataBytes, 0, authDataBytes.Length); - - // The format is DOMAIN\username:password - // Domain is optional - - pos = authString.IndexOf(':'); - - // parse the password off the end - password = authString.Substring(pos + 1); - - // discard the password - authString = authString.Substring(0, pos); - - // check if there is a domain - pos = authString.IndexOf('\\'); - - if (pos > 0) - { - //domain = authString.Substring (0, pos); - user = authString.Substring(pos); - } - else - { - user = authString; - } - - HttpListenerBasicIdentity identity = new HttpListenerBasicIdentity(user, password); - // TODO: What are the roles MS sets - return new GenericPrincipal(identity, new string[0]); - } - catch (Exception) - { - // Invalid auth data is swallowed silently - return null; - } + return AcceptWebSocketAsync(subProtocol, HttpWebSocket.DefaultReceiveBufferSize, WebSocket.DefaultKeepAliveInterval); } - public HttpListenerWebSocketContext AcceptWebSocket(string protocol) + public Task<HttpListenerWebSocketContext> AcceptWebSocketAsync(string subProtocol, TimeSpan keepAliveInterval) { - if (protocol != null) - { - if (protocol.Length == 0) - throw new ArgumentException("An empty string.", "protocol"); - - if (!protocol.IsToken()) - throw new ArgumentException("Contains an invalid character.", "protocol"); - } - - return new HttpListenerWebSocketContext(this, protocol, _cryptoProvider, _memoryStreamFactory); + return AcceptWebSocketAsync(subProtocol, HttpWebSocket.DefaultReceiveBufferSize, keepAliveInterval); } } diff --git a/SocketHttpListener/Net/HttpListenerPrefixCollection.cs b/SocketHttpListener/Net/HttpListenerPrefixCollection.cs index 0b05539ee..53efcb0fa 100644 --- a/SocketHttpListener/Net/HttpListenerPrefixCollection.cs +++ b/SocketHttpListener/Net/HttpListenerPrefixCollection.cs @@ -36,13 +36,13 @@ namespace SocketHttpListener.Net public void Add(string uriPrefix) { listener.CheckDisposed(); - ListenerPrefix.CheckUri(uriPrefix); + //ListenerPrefix.CheckUri(uriPrefix); if (prefixes.Contains(uriPrefix)) return; prefixes.Add(uriPrefix); if (listener.IsListening) - EndPointManager.AddPrefix(_logger, uriPrefix, listener); + HttpEndPointManager.AddPrefix(_logger, uriPrefix, listener); } public void Clear() @@ -50,7 +50,7 @@ namespace SocketHttpListener.Net listener.CheckDisposed(); prefixes.Clear(); if (listener.IsListening) - EndPointManager.RemoveListener(_logger, listener); + HttpEndPointManager.RemoveListener(_logger, listener); } public bool Contains(string uriPrefix) @@ -89,7 +89,7 @@ namespace SocketHttpListener.Net bool result = prefixes.Remove(uriPrefix); if (result && listener.IsListening) - EndPointManager.RemovePrefix(_logger, uriPrefix, listener); + HttpEndPointManager.RemovePrefix(_logger, uriPrefix, listener); return result; } diff --git a/SocketHttpListener/Net/HttpListenerRequest.Managed.cs b/SocketHttpListener/Net/HttpListenerRequest.Managed.cs new file mode 100644 index 000000000..47a6dfcfd --- /dev/null +++ b/SocketHttpListener/Net/HttpListenerRequest.Managed.cs @@ -0,0 +1,330 @@ +using System; +using System.Text; +using System.Collections.Specialized; +using System.Globalization; +using System.IO; +using System.Security.Authentication.ExtendedProtection; +using System.Security.Cryptography.X509Certificates; +using MediaBrowser.Model.Services; +using MediaBrowser.Model.Text; + +namespace SocketHttpListener.Net +{ + public sealed partial class HttpListenerRequest + { + private long _contentLength; + private bool _clSet; + private WebHeaderCollection _headers; + private string _method; + private Stream _inputStream; + private HttpListenerContext _context; + private bool _isChunked; + + private static byte[] s_100continue = Encoding.ASCII.GetBytes("HTTP/1.1 100 Continue\r\n\r\n"); + + internal HttpListenerRequest(HttpListenerContext context) + { + _context = context; + _headers = new WebHeaderCollection(); + _version = HttpVersion.Version10; + } + + private static readonly char[] s_separators = new char[] { ' ' }; + + internal void SetRequestLine(string req) + { + string[] parts = req.Split(s_separators, 3); + if (parts.Length != 3) + { + _context.ErrorMessage = "Invalid request line (parts)."; + return; + } + + _method = parts[0]; + foreach (char c in _method) + { + int ic = (int)c; + + if ((ic >= 'A' && ic <= 'Z') || + (ic > 32 && c < 127 && c != '(' && c != ')' && c != '<' && + c != '<' && c != '>' && c != '@' && c != ',' && c != ';' && + c != ':' && c != '\\' && c != '"' && c != '/' && c != '[' && + c != ']' && c != '?' && c != '=' && c != '{' && c != '}')) + continue; + + _context.ErrorMessage = "(Invalid verb)"; + return; + } + + _rawUrl = parts[1]; + if (parts[2].Length != 8 || !parts[2].StartsWith("HTTP/")) + { + _context.ErrorMessage = "Invalid request line (version)."; + return; + } + + try + { + _version = new Version(parts[2].Substring(5)); + } + catch + { + _context.ErrorMessage = "Invalid request line (version)."; + return; + } + + if (_version.Major < 1) + { + _context.ErrorMessage = "Invalid request line (version)."; + return; + } + if (_version.Major > 1) + { + _context.ErrorStatus = (int)HttpStatusCode.HttpVersionNotSupported; + _context.ErrorMessage = HttpStatusDescription.Get(HttpStatusCode.HttpVersionNotSupported); + return; + } + } + + private static bool MaybeUri(string s) + { + int p = s.IndexOf(':'); + if (p == -1) + return false; + + if (p >= 10) + return false; + + return IsPredefinedScheme(s.Substring(0, p)); + } + + private static bool IsPredefinedScheme(string scheme) + { + if (scheme == null || scheme.Length < 3) + return false; + + char c = scheme[0]; + if (c == 'h') + return (scheme == UriScheme.Http || scheme == UriScheme.Https); + if (c == 'f') + return (scheme == UriScheme.File || scheme == UriScheme.Ftp); + + if (c == 'n') + { + c = scheme[1]; + if (c == 'e') + return (scheme == UriScheme.News || scheme == UriScheme.NetPipe || scheme == UriScheme.NetTcp); + if (scheme == UriScheme.Nntp) + return true; + return false; + } + if ((c == 'g' && scheme == UriScheme.Gopher) || (c == 'm' && scheme == UriScheme.Mailto)) + return true; + + return false; + } + + internal void FinishInitialization() + { + string host = UserHostName; + if (_version > HttpVersion.Version10 && (host == null || host.Length == 0)) + { + _context.ErrorMessage = "Invalid host name"; + return; + } + + string path; + Uri raw_uri = null; + if (MaybeUri(_rawUrl.ToLowerInvariant()) && Uri.TryCreate(_rawUrl, UriKind.Absolute, out raw_uri)) + path = raw_uri.PathAndQuery; + else + path = _rawUrl; + + if ((host == null || host.Length == 0)) + host = UserHostAddress; + + if (raw_uri != null) + host = raw_uri.Host; + + int colon = host.IndexOf(']') == -1 ? host.IndexOf(':') : host.LastIndexOf(':'); + if (colon >= 0) + host = host.Substring(0, colon); + + string base_uri = string.Format("{0}://{1}:{2}", RequestScheme, host, LocalEndPoint.Port); + + if (!Uri.TryCreate(base_uri + path, UriKind.Absolute, out _requestUri)) + { + _context.ErrorMessage = System.Net.WebUtility.HtmlEncode("Invalid url: " + base_uri + path); + return; + } + + _requestUri = HttpListenerRequestUriBuilder.GetRequestUri(_rawUrl, _requestUri.Scheme, + _requestUri.Authority, _requestUri.LocalPath, _requestUri.Query); + + if (_version >= HttpVersion.Version11) + { + string t_encoding = Headers[HttpKnownHeaderNames.TransferEncoding]; + _isChunked = (t_encoding != null && string.Equals(t_encoding, "chunked", StringComparison.OrdinalIgnoreCase)); + // 'identity' is not valid! + if (t_encoding != null && !_isChunked) + { + _context.Connection.SendError(null, 501); + return; + } + } + + if (!_isChunked && !_clSet) + { + if (string.Equals(_method, "POST", StringComparison.OrdinalIgnoreCase) || + string.Equals(_method, "PUT", StringComparison.OrdinalIgnoreCase)) + { + _context.Connection.SendError(null, 411); + return; + } + } + + if (String.Compare(Headers[HttpKnownHeaderNames.Expect], "100-continue", StringComparison.OrdinalIgnoreCase) == 0) + { + HttpResponseStream output = _context.Connection.GetResponseStream(); + output.InternalWrite(s_100continue, 0, s_100continue.Length); + } + } + + internal static string Unquote(String str) + { + int start = str.IndexOf('\"'); + int end = str.LastIndexOf('\"'); + if (start >= 0 && end >= 0) + str = str.Substring(start + 1, end - 1); + return str.Trim(); + } + + internal void AddHeader(string header) + { + int colon = header.IndexOf(':'); + if (colon == -1 || colon == 0) + { + _context.ErrorMessage = HttpStatusDescription.Get(400); + _context.ErrorStatus = 400; + return; + } + + string name = header.Substring(0, colon).Trim(); + string val = header.Substring(colon + 1).Trim(); + if (name.Equals("content-length", StringComparison.OrdinalIgnoreCase)) + { + // To match Windows behavior: + // Content lengths >= 0 and <= long.MaxValue are accepted as is. + // Content lengths > long.MaxValue and <= ulong.MaxValue are treated as 0. + // Content lengths < 0 cause the requests to fail. + // Other input is a failure, too. + long parsedContentLength = + ulong.TryParse(val, out ulong parsedUlongContentLength) ? (parsedUlongContentLength <= long.MaxValue ? (long)parsedUlongContentLength : 0) : + long.Parse(val); + if (parsedContentLength < 0 || (_clSet && parsedContentLength != _contentLength)) + { + _context.ErrorMessage = "Invalid Content-Length."; + } + else + { + _contentLength = parsedContentLength; + _clSet = true; + } + } + else if (name.Equals("transfer-encoding", StringComparison.OrdinalIgnoreCase)) + { + if (Headers[HttpKnownHeaderNames.TransferEncoding] != null) + { + _context.ErrorStatus = (int)HttpStatusCode.NotImplemented; + _context.ErrorMessage = HttpStatusDescription.Get(HttpStatusCode.NotImplemented); + } + } + + if (_context.ErrorMessage == null) + { + _headers.Set(name, val); + } + } + + // returns true is the stream could be reused. + internal bool FlushInput() + { + if (!HasEntityBody) + return true; + + int length = 2048; + if (_contentLength > 0) + length = (int)Math.Min(_contentLength, (long)length); + + byte[] bytes = new byte[length]; + while (true) + { + try + { + IAsyncResult ares = InputStream.BeginRead(bytes, 0, length, null, null); + if (!ares.IsCompleted && !ares.AsyncWaitHandle.WaitOne(1000)) + return false; + if (InputStream.EndRead(ares) <= 0) + return true; + } + catch (ObjectDisposedException) + { + _inputStream = null; + return true; + } + catch + { + return false; + } + } + } + + public long ContentLength64 + { + get + { + if (_isChunked) + _contentLength = -1; + + return _contentLength; + } + } + + public bool HasEntityBody => (_contentLength > 0 || _isChunked); + + public QueryParamCollection Headers => _headers; + + public string HttpMethod => _method; + + public Stream InputStream + { + get + { + if (_inputStream == null) + { + if (_isChunked || _contentLength > 0) + _inputStream = _context.Connection.GetRequestStream(_isChunked, _contentLength); + else + _inputStream = Stream.Null; + } + + return _inputStream; + } + } + + public bool IsAuthenticated => false; + + public bool IsSecureConnection => _context.Connection.IsSecure; + + public System.Net.IPEndPoint LocalEndPoint => _context.Connection.LocalEndPoint; + + public System.Net.IPEndPoint RemoteEndPoint => _context.Connection.RemoteEndPoint; + + public Guid RequestTraceIdentifier { get; } = Guid.NewGuid(); + + public string ServiceName => null; + + private Uri RequestUri => _requestUri; + private bool SupportsWebSockets => true; + } +} diff --git a/SocketHttpListener/Net/HttpListenerRequest.cs b/SocketHttpListener/Net/HttpListenerRequest.cs index 5e391424f..1b369dfa8 100644 --- a/SocketHttpListener/Net/HttpListenerRequest.cs +++ b/SocketHttpListener/Net/HttpListenerRequest.cs @@ -10,653 +10,534 @@ using MediaBrowser.Model.Net; using MediaBrowser.Model.Services; using MediaBrowser.Model.Text; using SocketHttpListener.Primitives; +using System.Collections.Generic; +using SocketHttpListener.Net.WebSockets; namespace SocketHttpListener.Net { - public sealed class HttpListenerRequest + public sealed unsafe partial class HttpListenerRequest { - string[] accept_types; - Encoding content_encoding; - long content_length; - bool cl_set; - CookieCollection cookies; - WebHeaderCollection headers; - string method; - Stream input_stream; - Version version; - QueryParamCollection query_string; // check if null is ok, check if read-only, check case-sensitiveness - string raw_url; - Uri url; - Uri referrer; - string[] user_languages; - HttpListenerContext context; - bool is_chunked; - bool ka_set; - bool? _keepAlive; - - private readonly ITextEncoding _textEncoding; - - internal HttpListenerRequest(HttpListenerContext context, ITextEncoding textEncoding) - { - this.context = context; - _textEncoding = textEncoding; - headers = new WebHeaderCollection(); - version = HttpVersion.Version10; - } + private CookieCollection _cookies; + private bool? _keepAlive; + private string _rawUrl; + private Uri _requestUri; + private Version _version; - static char[] separators = new char[] { ' ' }; + public string[] AcceptTypes => Helpers.ParseMultivalueHeader(Headers[HttpKnownHeaderNames.Accept]); - internal void SetRequestLine(string req) - { - string[] parts = req.Split(separators, 3); - if (parts.Length != 3) - { - context.ErrorMessage = "Invalid request line (parts)."; - return; - } + public string[] UserLanguages => Helpers.ParseMultivalueHeader(Headers[HttpKnownHeaderNames.AcceptLanguage]); - method = parts[0]; - foreach (char c in method) - { - int ic = (int)c; - - if ((ic >= 'A' && ic <= 'Z') || - (ic > 32 && c < 127 && c != '(' && c != ')' && c != '<' && - c != '<' && c != '>' && c != '@' && c != ',' && c != ';' && - c != ':' && c != '\\' && c != '"' && c != '/' && c != '[' && - c != ']' && c != '?' && c != '=' && c != '{' && c != '}')) - continue; - - context.ErrorMessage = "(Invalid verb)"; - return; - } - - raw_url = parts[1]; - if (parts[2].Length != 8 || !parts[2].StartsWith("HTTP/")) - { - context.ErrorMessage = "Invalid request line (version)."; - return; - } - - try - { - version = new Version(parts[2].Substring(5)); - if (version.Major < 1) - throw new Exception(); - } - catch - { - context.ErrorMessage = "Invalid request line (version)."; - return; - } + private CookieCollection ParseCookies(Uri uri, string setCookieHeader) + { + CookieCollection cookies = new CookieCollection(); + return cookies; } - void CreateQueryString(string query) + public CookieCollection Cookies { - if (query == null || query.Length == 0) + get { - query_string = new QueryParamCollection(); - return; + if (_cookies == null) + { + string cookieString = Headers[HttpKnownHeaderNames.Cookie]; + if (!string.IsNullOrEmpty(cookieString)) + { + _cookies = ParseCookies(RequestUri, cookieString); + } + if (_cookies == null) + { + _cookies = new CookieCollection(); + } + } + return _cookies; } + } - query_string = new QueryParamCollection(); - if (query[0] == '?') - query = query.Substring(1); - string[] components = query.Split('&'); - foreach (string kv in components) + public Encoding ContentEncoding + { + get { - int pos = kv.IndexOf('='); - if (pos == -1) + if (UserAgent != null && CultureInfo.InvariantCulture.CompareInfo.IsPrefix(UserAgent, "UP")) { - query_string.Add(null, WebUtility.UrlDecode(kv)); + string postDataCharset = Headers["x-up-devcap-post-charset"]; + if (postDataCharset != null && postDataCharset.Length > 0) + { + try + { + return Encoding.GetEncoding(postDataCharset); + } + catch (ArgumentException) + { + } + } } - else + if (HasEntityBody) { - string key = WebUtility.UrlDecode(kv.Substring(0, pos)); - string val = WebUtility.UrlDecode(kv.Substring(pos + 1)); - - query_string.Add(key, val); + if (ContentType != null) + { + string charSet = Helpers.GetCharSetValueFromHeader(ContentType); + if (charSet != null) + { + try + { + return Encoding.GetEncoding(charSet); + } + catch (ArgumentException) + { + } + } + } } + return TextEncodingExtensions.GetDefaultEncoding(); } } - internal void FinishInitialization() - { - string host = UserHostName; - if (version > HttpVersion.Version10 && (host == null || host.Length == 0)) - { - context.ErrorMessage = "Invalid host name"; - return; - } - - string path; - Uri raw_uri = null; - if (MaybeUri(raw_url.ToLowerInvariant()) && Uri.TryCreate(raw_url, UriKind.Absolute, out raw_uri)) - path = raw_uri.PathAndQuery; - else - path = raw_url; - - if ((host == null || host.Length == 0)) - host = UserHostAddress; - - if (raw_uri != null) - host = raw_uri.Host; - - int colon = host.LastIndexOf(':'); - if (colon >= 0) - host = host.Substring(0, colon); + public string ContentType => Headers[HttpKnownHeaderNames.ContentType]; - string base_uri = String.Format("{0}://{1}:{2}", - (IsSecureConnection) ? (IsWebSocketRequest ? "wss" : "https") : (IsWebSocketRequest ? "ws" : "http"), - host, LocalEndPoint.Port); + public bool IsLocal => LocalEndPoint.Address.Equals(RemoteEndPoint.Address); - if (!Uri.TryCreate(base_uri + path, UriKind.Absolute, out url)) - { - context.ErrorMessage = WebUtility.HtmlEncode("Invalid url: " + base_uri + path); - return; return; - } - - CreateQueryString(url.Query); - - if (version >= HttpVersion.Version11) + public bool IsWebSocketRequest + { + get { - string t_encoding = Headers["Transfer-Encoding"]; - is_chunked = (t_encoding != null && String.Compare(t_encoding, "chunked", StringComparison.OrdinalIgnoreCase) == 0); - // 'identity' is not valid! - if (t_encoding != null && !is_chunked) + if (!SupportsWebSockets) { - context.Connection.SendError(null, 501); - return; + return false; } - } - if (!is_chunked && !cl_set) - { - if (String.Compare(method, "POST", StringComparison.OrdinalIgnoreCase) == 0 || - String.Compare(method, "PUT", StringComparison.OrdinalIgnoreCase) == 0) + bool foundConnectionUpgradeHeader = false; + if (string.IsNullOrEmpty(Headers[HttpKnownHeaderNames.Connection]) || string.IsNullOrEmpty(Headers[HttpKnownHeaderNames.Upgrade])) { - context.Connection.SendError(null, 411); - return; + return false; } - } - - if (String.Compare(Headers["Expect"], "100-continue", StringComparison.OrdinalIgnoreCase) == 0) - { - var output = (HttpResponseStream)context.Connection.GetResponseStream(true); - - var _100continue = _textEncoding.GetASCIIEncoding().GetBytes("HTTP/1.1 100 Continue\r\n\r\n"); - - output.InternalWrite(_100continue, 0, _100continue.Length); - } - } - - static bool MaybeUri(string s) - { - int p = s.IndexOf(':'); - if (p == -1) - return false; - if (p >= 10) - return false; - - return IsPredefinedScheme(s.Substring(0, p)); - } + foreach (string connection in Headers.GetValues(HttpKnownHeaderNames.Connection)) + { + if (string.Equals(connection, HttpKnownHeaderNames.Upgrade, StringComparison.OrdinalIgnoreCase)) + { + foundConnectionUpgradeHeader = true; + break; + } + } - // - // Using a simple block of if's is twice as slow as the compiler generated - // switch statement. But using this tuned code is faster than the - // compiler generated code, with a million loops on x86-64: - // - // With "http": .10 vs .51 (first check) - // with "https": .16 vs .51 (second check) - // with "foo": .22 vs .31 (never found) - // with "mailto": .12 vs .51 (last check) - // - // - static bool IsPredefinedScheme(string scheme) - { - if (scheme == null || scheme.Length < 3) - return false; + if (!foundConnectionUpgradeHeader) + { + return false; + } - char c = scheme[0]; - if (c == 'h') - return (scheme == "http" || scheme == "https"); - if (c == 'f') - return (scheme == "file" || scheme == "ftp"); + foreach (string upgrade in Headers.GetValues(HttpKnownHeaderNames.Upgrade)) + { + if (string.Equals(upgrade, HttpWebSocket.WebSocketUpgradeToken, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + } - if (c == 'n') - { - c = scheme[1]; - if (c == 'e') - return (scheme == "news" || scheme == "net.pipe" || scheme == "net.tcp"); - if (scheme == "nntp") - return true; return false; } - if ((c == 'g' && scheme == "gopher") || (c == 'm' && scheme == "mailto")) - return true; - - return false; } - internal static string Unquote(String str) - { - int start = str.IndexOf('\"'); - int end = str.LastIndexOf('\"'); - if (start >= 0 && end >= 0) - str = str.Substring(start + 1, end - 1); - return str.Trim(); - } - - internal void AddHeader(string header) + public bool KeepAlive { - int colon = header.IndexOf(':'); - if (colon == -1 || colon == 0) - { - context.ErrorMessage = "Bad Request"; - context.ErrorStatus = 400; - return; - } - - string name = header.Substring(0, colon).Trim(); - string val = header.Substring(colon + 1).Trim(); - string lower = name.ToLowerInvariant(); - headers.SetInternal(name, val); - switch (lower) + get { - case "accept-language": - user_languages = val.Split(','); // yes, only split with a ',' - break; - case "accept": - accept_types = val.Split(','); // yes, only split with a ',' - break; - case "content-length": - try - { - //TODO: max. content_length? - content_length = Int64.Parse(val.Trim()); - if (content_length < 0) - context.ErrorMessage = "Invalid Content-Length."; - cl_set = true; - } - catch - { - context.ErrorMessage = "Invalid Content-Length."; - } - - break; - case "content-type": - { - var contents = val.Split(';'); - foreach (var content in contents) - { - var tmp = content.Trim(); - if (tmp.StartsWith("charset")) - { - var charset = tmp.GetValue("="); - if (charset != null && charset.Length > 0) - { - try - { - - // Support upnp/dlna devices - CONTENT-TYPE: text/xml ; charset="utf-8"\r\n - charset = charset.Trim('"'); - var index = charset.IndexOf('"'); - if (index != -1) charset = charset.Substring(0, index); - - content_encoding = Encoding.GetEncoding(charset); - } - catch - { - context.ErrorMessage = "Invalid Content-Type header: " + charset; - } - } - - break; - } - } - } - break; - case "referer": - try - { - referrer = new Uri(val); - } - catch + if (!_keepAlive.HasValue) + { + string header = Headers[HttpKnownHeaderNames.ProxyConnection]; + if (string.IsNullOrEmpty(header)) { - referrer = new Uri("http://someone.is.screwing.with.the.headers.com/"); + header = Headers[HttpKnownHeaderNames.Connection]; } - break; - case "cookie": - if (cookies == null) - cookies = new CookieCollection(); - - string[] cookieStrings = val.Split(new char[] { ',', ';' }); - Cookie current = null; - int version = 0; - foreach (string cookieString in cookieStrings) + if (string.IsNullOrEmpty(header)) { - string str = cookieString.Trim(); - if (str.Length == 0) - continue; - if (str.StartsWith("$Version")) - { - version = Int32.Parse(Unquote(str.Substring(str.IndexOf('=') + 1))); - } - else if (str.StartsWith("$Path")) - { - if (current != null) - current.Path = str.Substring(str.IndexOf('=') + 1).Trim(); - } - else if (str.StartsWith("$Domain")) - { - if (current != null) - current.Domain = str.Substring(str.IndexOf('=') + 1).Trim(); - } - else if (str.StartsWith("$Port")) + if (ProtocolVersion >= HttpVersion.Version11) { - if (current != null) - current.Port = str.Substring(str.IndexOf('=') + 1).Trim(); + _keepAlive = true; } else { - if (current != null) - { - cookies.Add(current); - } - current = new Cookie(); - int idx = str.IndexOf('='); - if (idx > 0) - { - current.Name = str.Substring(0, idx).Trim(); - current.Value = str.Substring(idx + 1).Trim(); - } - else - { - current.Name = str.Trim(); - current.Value = String.Empty; - } - current.Version = version; + header = Headers[HttpKnownHeaderNames.KeepAlive]; + _keepAlive = !string.IsNullOrEmpty(header); } } - if (current != null) - { - cookies.Add(current); - } - break; - } - } - - // returns true is the stream could be reused. - internal bool FlushInput() - { - if (!HasEntityBody) - return true; - - int length = 2048; - if (content_length > 0) - length = (int)Math.Min(content_length, (long)length); - - byte[] bytes = new byte[length]; - while (true) - { - // TODO: test if MS has a timeout when doing this - try - { - var task = InputStream.ReadAsync(bytes, 0, length); - var result = Task.WaitAll(new [] { task }, 1000); - if (!result) - { - return false; - } - if (task.Result <= 0) + else { - return true; + header = header.ToLower(CultureInfo.InvariantCulture); + _keepAlive = + header.IndexOf("close", StringComparison.OrdinalIgnoreCase) < 0 || + header.IndexOf("keep-alive", StringComparison.OrdinalIgnoreCase) >= 0; } } - catch (ObjectDisposedException e) - { - input_stream = null; - return true; - } - catch - { - return false; - } - } - } - public string[] AcceptTypes - { - get { return accept_types; } + return _keepAlive.Value; + } } - public int ClientCertificateError + public QueryParamCollection QueryString { get { - HttpConnection cnc = context.Connection; - //if (cnc.ClientCertificate == null) - // throw new InvalidOperationException("No client certificate"); - //int[] errors = cnc.ClientCertificateErrors; - //if (errors != null && errors.Length > 0) - // return errors[0]; - return 0; + QueryParamCollection queryString = new QueryParamCollection(); + Helpers.FillFromString(queryString, Url.Query, true, ContentEncoding); + return queryString; } } - public Encoding ContentEncoding + public string RawUrl => _rawUrl; + + private string RequestScheme => IsSecureConnection ? UriScheme.Https : UriScheme.Http; + + public string UserAgent => Headers[HttpKnownHeaderNames.UserAgent]; + + public string UserHostAddress => LocalEndPoint.ToString(); + + public string UserHostName => Headers[HttpKnownHeaderNames.Host]; + + public Uri UrlReferrer { get { - if (content_encoding == null) - content_encoding = _textEncoding.GetDefaultEncoding(); - return content_encoding; + string referrer = Headers[HttpKnownHeaderNames.Referer]; + if (referrer == null) + { + return null; + } + + bool success = Uri.TryCreate(referrer, UriKind.RelativeOrAbsolute, out Uri urlReferrer); + return success ? urlReferrer : null; } } - public long ContentLength64 - { - get { return is_chunked ? -1 : content_length; } - } + public Uri Url => RequestUri; - public string ContentType - { - get { return headers["content-type"]; } - } + public Version ProtocolVersion => _version; - public CookieCollection Cookies + private static class Helpers { - get + // + // Get attribute off header value + // + internal static string GetCharSetValueFromHeader(string headerValue) { - // TODO: check if the collection is read-only - if (cookies == null) - cookies = new CookieCollection(); - return cookies; - } - } + const string AttrName = "charset"; - public bool HasEntityBody - { - get { return (content_length > 0 || is_chunked); } - } + if (headerValue == null) + return null; - public QueryParamCollection Headers - { - get { return headers; } - } + int l = headerValue.Length; + int k = AttrName.Length; - public string HttpMethod - { - get { return method; } - } + // find properly separated attribute name + int i = 1; // start searching from 1 - public Stream InputStream - { - get - { - if (input_stream == null) + while (i < l) { - if (is_chunked || content_length > 0) - input_stream = context.Connection.GetRequestStream(is_chunked, content_length); - else - input_stream = Stream.Null; + i = CultureInfo.InvariantCulture.CompareInfo.IndexOf(headerValue, AttrName, i, CompareOptions.IgnoreCase); + if (i < 0) + break; + if (i + k >= l) + break; + + char chPrev = headerValue[i - 1]; + char chNext = headerValue[i + k]; + if ((chPrev == ';' || chPrev == ',' || char.IsWhiteSpace(chPrev)) && (chNext == '=' || char.IsWhiteSpace(chNext))) + break; + + i += k; } - return input_stream; - } - } + if (i < 0 || i >= l) + return null; - public bool IsAuthenticated - { - get { return false; } - } + // skip to '=' and the following whitespace + i += k; + while (i < l && char.IsWhiteSpace(headerValue[i])) + i++; + if (i >= l || headerValue[i] != '=') + return null; + i++; + while (i < l && char.IsWhiteSpace(headerValue[i])) + i++; + if (i >= l) + return null; - public bool IsLocal - { - get + // parse the value + string attrValue = null; + + int j; + + if (i < l && headerValue[i] == '"') + { + if (i == l - 1) + return null; + j = headerValue.IndexOf('"', i + 1); + if (j < 0 || j == i + 1) + return null; + + attrValue = headerValue.Substring(i + 1, j - i - 1).Trim(); + } + else + { + for (j = i; j < l; j++) + { + if (headerValue[j] == ';') + break; + } + + if (j == i) + return null; + + attrValue = headerValue.Substring(i, j - i).Trim(); + } + + return attrValue; + } + + internal static string[] ParseMultivalueHeader(string s) { - var remoteEndPoint = RemoteEndPoint; + if (s == null) + return null; + + int l = s.Length; + + // collect comma-separated values into list + + List<string> values = new List<string>(); + int i = 0; + + while (i < l) + { + // find next , + int ci = s.IndexOf(',', i); + if (ci < 0) + ci = l; + + // append corresponding server value + values.Add(s.Substring(i, ci - i)); + + // move to next + i = ci + 1; + + // skip leading space + if (i < l && s[i] == ' ') + i++; + } - return remoteEndPoint.Address.Equals(IPAddress.Loopback) || - remoteEndPoint.Address.Equals(IPAddress.IPv6Loopback) || - LocalEndPoint.Address.Equals(remoteEndPoint.Address); + // return list as array of strings + + int n = values.Count; + string[] strings; + + // if n is 0 that means s was empty string + + if (n == 0) + { + strings = new string[1]; + strings[0] = string.Empty; + } + else + { + strings = new string[n]; + values.CopyTo(0, strings, 0, n); + } + return strings; } - } - public bool IsSecureConnection - { - get { return context.Connection.IsSecure; } - } - public bool KeepAlive - { - get + private static string UrlDecodeStringFromStringInternal(string s, Encoding e) { - if (!_keepAlive.HasValue) + int count = s.Length; + UrlDecoder helper = new UrlDecoder(count, e); + + // go through the string's chars collapsing %XX and %uXXXX and + // appending each char as char, with exception of %XX constructs + // that are appended as bytes + + for (int pos = 0; pos < count; pos++) { - string header = Headers["Proxy-Connection"]; - if (string.IsNullOrEmpty(header)) + char ch = s[pos]; + + if (ch == '+') { - header = Headers["Connection"]; + ch = ' '; } - if (string.IsNullOrEmpty(header)) + else if (ch == '%' && pos < count - 2) { - if (ProtocolVersion >= HttpVersion.Version11) + if (s[pos + 1] == 'u' && pos < count - 5) { - _keepAlive = true; + int h1 = HexToInt(s[pos + 2]); + int h2 = HexToInt(s[pos + 3]); + int h3 = HexToInt(s[pos + 4]); + int h4 = HexToInt(s[pos + 5]); + + if (h1 >= 0 && h2 >= 0 && h3 >= 0 && h4 >= 0) + { // valid 4 hex chars + ch = (char)((h1 << 12) | (h2 << 8) | (h3 << 4) | h4); + pos += 5; + + // only add as char + helper.AddChar(ch); + continue; + } } else { - header = Headers["Keep-Alive"]; - _keepAlive = !string.IsNullOrEmpty(header); + int h1 = HexToInt(s[pos + 1]); + int h2 = HexToInt(s[pos + 2]); + + if (h1 >= 0 && h2 >= 0) + { // valid 2 hex chars + byte b = (byte)((h1 << 4) | h2); + pos += 2; + + // don't add as char + helper.AddByte(b); + continue; + } } } + + if ((ch & 0xFF80) == 0) + helper.AddByte((byte)ch); // 7 bit have to go as bytes because of Unicode else - { - header = header.ToLower(CultureInfo.InvariantCulture); - _keepAlive = - header.IndexOf("close", StringComparison.OrdinalIgnoreCase) < 0 || - header.IndexOf("keep-alive", StringComparison.OrdinalIgnoreCase) >= 0; - } + helper.AddChar(ch); } - return _keepAlive.Value; + return helper.GetString(); } - } - public IPEndPoint LocalEndPoint - { - get { return context.Connection.LocalEndPoint; } - } + private static int HexToInt(char h) + { + return (h >= '0' && h <= '9') ? h - '0' : + (h >= 'a' && h <= 'f') ? h - 'a' + 10 : + (h >= 'A' && h <= 'F') ? h - 'A' + 10 : + -1; + } - public Version ProtocolVersion - { - get { return version; } - } + private class UrlDecoder + { + private int _bufferSize; - public QueryParamCollection QueryString - { - get { return query_string; } - } + // Accumulate characters in a special array + private int _numChars; + private char[] _charBuffer; - public string RawUrl - { - get { return raw_url; } - } + // Accumulate bytes for decoding into characters in a special array + private int _numBytes; + private byte[] _byteBuffer; - public IPEndPoint RemoteEndPoint - { - get { return context.Connection.RemoteEndPoint; } - } + // Encoding to convert chars to bytes + private Encoding _encoding; - public Guid RequestTraceIdentifier - { - get { return Guid.Empty; } - } + private void FlushBytes() + { + if (_numBytes > 0) + { + _numChars += _encoding.GetChars(_byteBuffer, 0, _numBytes, _charBuffer, _numChars); + _numBytes = 0; + } + } - public Uri Url - { - get { return url; } - } + internal UrlDecoder(int bufferSize, Encoding encoding) + { + _bufferSize = bufferSize; + _encoding = encoding; - public Uri UrlReferrer - { - get { return referrer; } - } + _charBuffer = new char[bufferSize]; + // byte buffer created on demand + } - public string UserAgent - { - get { return headers["user-agent"]; } - } + internal void AddChar(char ch) + { + if (_numBytes > 0) + FlushBytes(); - public string UserHostAddress - { - get { return LocalEndPoint.ToString(); } - } + _charBuffer[_numChars++] = ch; + } - public string UserHostName - { - get { return headers["host"]; } - } + internal void AddByte(byte b) + { + { + if (_byteBuffer == null) + _byteBuffer = new byte[_bufferSize]; - public string[] UserLanguages - { - get { return user_languages; } - } + _byteBuffer[_numBytes++] = b; + } + } - public string ServiceName - { - get - { - return null; + internal string GetString() + { + if (_numBytes > 0) + FlushBytes(); + + if (_numChars > 0) + return new string(_charBuffer, 0, _numChars); + else + return string.Empty; + } } - } - private bool _websocketRequestWasSet; - private bool _websocketRequest; - /// <summary> - /// Gets a value indicating whether the request is a WebSocket connection request. - /// </summary> - /// <value> - /// <c>true</c> if the request is a WebSocket connection request; otherwise, <c>false</c>. - /// </value> - public bool IsWebSocketRequest - { - get + internal static void FillFromString(QueryParamCollection nvc, string s, bool urlencoded, Encoding encoding) { - if (!_websocketRequestWasSet) + int l = (s != null) ? s.Length : 0; + int i = (s.Length > 0 && s[0] == '?') ? 1 : 0; + + while (i < l) { - _websocketRequest = method == "GET" && - version > HttpVersion.Version10 && - headers.Contains("Upgrade", "websocket") && - headers.Contains("Connection", "Upgrade"); + // find next & while noting first = on the way (and if there are more) - _websocketRequestWasSet = true; - } + int si = i; + int ti = -1; - return _websocketRequest; + while (i < l) + { + char ch = s[i]; + + if (ch == '=') + { + if (ti < 0) + ti = i; + } + else if (ch == '&') + { + break; + } + + i++; + } + + // extract the name / value pair + + string name = null; + string value = null; + + if (ti >= 0) + { + name = s.Substring(si, ti - si); + value = s.Substring(ti + 1, i - ti - 1); + } + else + { + value = s.Substring(si, i - si); + } + + // add name / value pair to the collection + + if (urlencoded) + nvc.Add( + name == null ? null : UrlDecodeStringFromStringInternal(name, encoding), + UrlDecodeStringFromStringInternal(value, encoding)); + else + nvc.Add(name, value); + + // trailing '&' + + if (i == l - 1 && s[i] == '&') + nvc.Add(null, ""); + + i++; + } } } } diff --git a/SocketHttpListener/Net/HttpListenerRequestUriBuilder.cs b/SocketHttpListener/Net/HttpListenerRequestUriBuilder.cs new file mode 100644 index 000000000..e61bde32e --- /dev/null +++ b/SocketHttpListener/Net/HttpListenerRequestUriBuilder.cs @@ -0,0 +1,445 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System.Globalization; + +namespace SocketHttpListener.Net +{ + // We don't use the cooked URL because http.sys unescapes all percent-encoded values. However, + // we also can't just use the raw Uri, since http.sys supports not only Utf-8, but also ANSI/DBCS and + // Unicode code points. System.Uri only supports Utf-8. + // The purpose of this class is to convert all ANSI, DBCS, and Unicode code points into percent encoded + // Utf-8 characters. + internal sealed class HttpListenerRequestUriBuilder + { + private static readonly Encoding s_utf8Encoding = new UTF8Encoding(false, true); + private static readonly Encoding s_ansiEncoding = Encoding.GetEncoding(0, new EncoderExceptionFallback(), new DecoderExceptionFallback()); + + private readonly string _rawUri; + private readonly string _cookedUriScheme; + private readonly string _cookedUriHost; + private readonly string _cookedUriPath; + private readonly string _cookedUriQuery; + + // This field is used to build the final request Uri string from the Uri parts passed to the ctor. + private StringBuilder _requestUriString; + + // The raw path is parsed by looping through all characters from left to right. 'rawOctets' + // is used to store consecutive percent encoded octets as actual byte values: e.g. for path /pa%C3%84th%2F/ + // rawOctets will be set to { 0xC3, 0x84 } when we reach character 't' and it will be { 0x2F } when + // we reach the final '/'. I.e. after a sequence of percent encoded octets ends, we use rawOctets as + // input to the encoding and percent encode the resulting string into UTF-8 octets. + // + // When parsing ANSI (Latin 1) encoded path '/pa%C4th/', %C4 will be added to rawOctets and when + // we reach 't', the content of rawOctets { 0xC4 } will be fed into the ANSI encoding. The resulting + // string 'Ä' will be percent encoded into UTF-8 octets and appended to requestUriString. The final + // path will be '/pa%C3%84th/', where '%C3%84' is the UTF-8 percent encoded character 'Ä'. + private List<byte> _rawOctets; + private string _rawPath; + + // Holds the final request Uri. + private Uri _requestUri; + + private HttpListenerRequestUriBuilder(string rawUri, string cookedUriScheme, string cookedUriHost, + string cookedUriPath, string cookedUriQuery) + { + _rawUri = rawUri; + _cookedUriScheme = cookedUriScheme; + _cookedUriHost = cookedUriHost; + _cookedUriPath = AddSlashToAsteriskOnlyPath(cookedUriPath); + _cookedUriQuery = cookedUriQuery ?? string.Empty; + } + + public static Uri GetRequestUri(string rawUri, string cookedUriScheme, string cookedUriHost, + string cookedUriPath, string cookedUriQuery) + { + HttpListenerRequestUriBuilder builder = new HttpListenerRequestUriBuilder(rawUri, + cookedUriScheme, cookedUriHost, cookedUriPath, cookedUriQuery); + + return builder.Build(); + } + + private Uri Build() + { + BuildRequestUriUsingRawPath(); + + if (_requestUri == null) + { + BuildRequestUriUsingCookedPath(); + } + + return _requestUri; + } + + private void BuildRequestUriUsingCookedPath() + { + bool isValid = Uri.TryCreate(_cookedUriScheme + Uri.SchemeDelimiter + _cookedUriHost + _cookedUriPath + + _cookedUriQuery, UriKind.Absolute, out _requestUri); + + // Creating a Uri from the cooked Uri should really always work: If not, we log at least. + if (!isValid) + { + //if (NetEventSource.IsEnabled) + // NetEventSource.Error(this, SR.Format(SR.net_log_listener_cant_create_uri, _cookedUriScheme, _cookedUriHost, _cookedUriPath, _cookedUriQuery)); + } + } + + private void BuildRequestUriUsingRawPath() + { + bool isValid = false; + + // Initialize 'rawPath' only if really needed; i.e. if we build the request Uri from the raw Uri. + _rawPath = GetPath(_rawUri); + + // Try to check the raw path using first the primary encoding (according to http.sys settings); + // if it fails try the secondary encoding. + ParsingResult result = BuildRequestUriUsingRawPath(GetEncoding(EncodingType.Primary)); + if (result == ParsingResult.EncodingError) + { + Encoding secondaryEncoding = GetEncoding(EncodingType.Secondary); + result = BuildRequestUriUsingRawPath(secondaryEncoding); + } + isValid = (result == ParsingResult.Success) ? true : false; + + // Log that we weren't able to create a Uri from the raw string. + if (!isValid) + { + //if (NetEventSource.IsEnabled) + // NetEventSource.Error(this, SR.Format(SR.net_log_listener_cant_create_uri, _cookedUriScheme, _cookedUriHost, _rawPath, _cookedUriQuery)); + } + } + + private static Encoding GetEncoding(EncodingType type) + { + Debug.Assert((type == EncodingType.Primary) || (type == EncodingType.Secondary), + "Unknown 'EncodingType' value: " + type.ToString()); + + if (type == EncodingType.Secondary) + { + return s_ansiEncoding; + } + else + { + return s_utf8Encoding; + } + } + + private ParsingResult BuildRequestUriUsingRawPath(Encoding encoding) + { + Debug.Assert(encoding != null, "'encoding' must be assigned."); + Debug.Assert(!string.IsNullOrEmpty(_rawPath), "'rawPath' must have at least one character."); + + _rawOctets = new List<byte>(); + _requestUriString = new StringBuilder(); + _requestUriString.Append(_cookedUriScheme); + _requestUriString.Append(Uri.SchemeDelimiter); + _requestUriString.Append(_cookedUriHost); + + ParsingResult result = ParseRawPath(encoding); + if (result == ParsingResult.Success) + { + _requestUriString.Append(_cookedUriQuery); + + Debug.Assert(_rawOctets.Count == 0, + "Still raw octets left. They must be added to the result path."); + + if (!Uri.TryCreate(_requestUriString.ToString(), UriKind.Absolute, out _requestUri)) + { + // If we can't create a Uri from the string, this is an invalid string and it doesn't make + // sense to try another encoding. + result = ParsingResult.InvalidString; + } + } + + if (result != ParsingResult.Success) + { + //if (NetEventSource.IsEnabled) + // NetEventSource.Error(this, SR.Format(SR.net_log_listener_cant_convert_raw_path, _rawPath, encoding.EncodingName)); + } + + return result; + } + + private ParsingResult ParseRawPath(Encoding encoding) + { + Debug.Assert(encoding != null, "'encoding' must be assigned."); + + int index = 0; + char current = '\0'; + while (index < _rawPath.Length) + { + current = _rawPath[index]; + if (current == '%') + { + // Assert is enough, since http.sys accepted the request string already. This should never happen. + Debug.Assert(index + 2 < _rawPath.Length, "Expected >=2 characters after '%' (e.g. %2F)"); + + index++; + current = _rawPath[index]; + if (current == 'u' || current == 'U') + { + // We found "%u" which means, we have a Unicode code point of the form "%uXXXX". + Debug.Assert(index + 4 < _rawPath.Length, "Expected >=4 characters after '%u' (e.g. %u0062)"); + + // Decode the content of rawOctets into percent encoded UTF-8 characters and append them + // to requestUriString. + if (!EmptyDecodeAndAppendRawOctetsList(encoding)) + { + return ParsingResult.EncodingError; + } + if (!AppendUnicodeCodePointValuePercentEncoded(_rawPath.Substring(index + 1, 4))) + { + return ParsingResult.InvalidString; + } + index += 5; + } + else + { + // We found '%', but not followed by 'u', i.e. we have a percent encoded octed: %XX + if (!AddPercentEncodedOctetToRawOctetsList(encoding, _rawPath.Substring(index, 2))) + { + return ParsingResult.InvalidString; + } + index += 2; + } + } + else + { + // We found a non-'%' character: decode the content of rawOctets into percent encoded + // UTF-8 characters and append it to the result. + if (!EmptyDecodeAndAppendRawOctetsList(encoding)) + { + return ParsingResult.EncodingError; + } + // Append the current character to the result. + _requestUriString.Append(current); + index++; + } + } + + // if the raw path ends with a sequence of percent encoded octets, make sure those get added to the + // result (requestUriString). + if (!EmptyDecodeAndAppendRawOctetsList(encoding)) + { + return ParsingResult.EncodingError; + } + + return ParsingResult.Success; + } + + private bool AppendUnicodeCodePointValuePercentEncoded(string codePoint) + { + // http.sys only supports %uXXXX (4 hex-digits), even though unicode code points could have up to + // 6 hex digits. Therefore we parse always 4 characters after %u and convert them to an int. + int codePointValue; + if (!int.TryParse(codePoint, NumberStyles.HexNumber, null, out codePointValue)) + { + //if (NetEventSource.IsEnabled) + // NetEventSource.Error(this, SR.Format(SR.net_log_listener_cant_convert_percent_value, codePoint)); + return false; + } + + string unicodeString = null; + try + { + unicodeString = char.ConvertFromUtf32(codePointValue); + AppendOctetsPercentEncoded(_requestUriString, s_utf8Encoding.GetBytes(unicodeString)); + + return true; + } + catch (ArgumentOutOfRangeException) + { + //if (NetEventSource.IsEnabled) + // NetEventSource.Error(this, SR.Format(SR.net_log_listener_cant_convert_percent_value, codePoint)); + } + catch (EncoderFallbackException e) + { + // If utf8Encoding.GetBytes() fails + //if (NetEventSource.IsEnabled) NetEventSource.Error(this, SR.Format(SR.net_log_listener_cant_convert_to_utf8, unicodeString, e.Message)); + } + + return false; + } + + private bool AddPercentEncodedOctetToRawOctetsList(Encoding encoding, string escapedCharacter) + { + byte encodedValue; + if (!byte.TryParse(escapedCharacter, NumberStyles.HexNumber, null, out encodedValue)) + { + //if (NetEventSource.IsEnabled) NetEventSource.Error(this, SR.Format(SR.net_log_listener_cant_convert_percent_value, escapedCharacter)); + return false; + } + + _rawOctets.Add(encodedValue); + + return true; + } + + private bool EmptyDecodeAndAppendRawOctetsList(Encoding encoding) + { + if (_rawOctets.Count == 0) + { + return true; + } + + string decodedString = null; + try + { + // If the encoding can get a string out of the byte array, this is a valid string in the + // 'encoding' encoding. + decodedString = encoding.GetString(_rawOctets.ToArray()); + + if (encoding == s_utf8Encoding) + { + AppendOctetsPercentEncoded(_requestUriString, _rawOctets.ToArray()); + } + else + { + AppendOctetsPercentEncoded(_requestUriString, s_utf8Encoding.GetBytes(decodedString)); + } + + _rawOctets.Clear(); + + return true; + } + catch (DecoderFallbackException e) + { + //if (NetEventSource.IsEnabled) NetEventSource.Error(this, SR.Format(SR.net_log_listener_cant_convert_bytes, GetOctetsAsString(_rawOctets), e.Message)); + } + catch (EncoderFallbackException e) + { + // If utf8Encoding.GetBytes() fails + //if (NetEventSource.IsEnabled) NetEventSource.Error(this, SR.Format(SR.net_log_listener_cant_convert_to_utf8, decodedString, e.Message)); + } + + return false; + } + + private static void AppendOctetsPercentEncoded(StringBuilder target, IEnumerable<byte> octets) + { + foreach (byte octet in octets) + { + target.Append('%'); + target.Append(octet.ToString("X2", CultureInfo.InvariantCulture)); + } + } + + private static string GetOctetsAsString(IEnumerable<byte> octets) + { + StringBuilder octetString = new StringBuilder(); + + bool first = true; + foreach (byte octet in octets) + { + if (first) + { + first = false; + } + else + { + octetString.Append(' '); + } + octetString.Append(octet.ToString("X2", CultureInfo.InvariantCulture)); + } + + return octetString.ToString(); + } + + private static string GetPath(string uriString) + { + Debug.Assert(uriString != null, "uriString must not be null"); + Debug.Assert(uriString.Length > 0, "uriString must not be empty"); + + int pathStartIndex = 0; + + // Perf. improvement: nearly all strings are relative Uris. So just look if the + // string starts with '/'. If so, we have a relative Uri and the path starts at position 0. + // (http.sys already trimmed leading whitespaces) + if (uriString[0] != '/') + { + // We can't check against cookedUriScheme, since http.sys allows for request http://myserver/ to + // use a request line 'GET https://myserver/' (note http vs. https). Therefore check if the + // Uri starts with either http:// or https://. + int authorityStartIndex = 0; + if (uriString.StartsWith("http://", StringComparison.OrdinalIgnoreCase)) + { + authorityStartIndex = 7; + } + else if (uriString.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + { + authorityStartIndex = 8; + } + + if (authorityStartIndex > 0) + { + // we have an absolute Uri. Find out where the authority ends and the path begins. + // Note that Uris like "http://server?query=value/1/2" are invalid according to RFC2616 + // and http.sys behavior: If the Uri contains a query, there must be at least one '/' + // between the authority and the '?' character: It's safe to just look for the first + // '/' after the authority to determine the beginning of the path. + pathStartIndex = uriString.IndexOf('/', authorityStartIndex); + if (pathStartIndex == -1) + { + // e.g. for request lines like: 'GET http://myserver' (no final '/') + pathStartIndex = uriString.Length; + } + } + else + { + // RFC2616: Request-URI = "*" | absoluteURI | abs_path | authority + // 'authority' can only be used with CONNECT which is never received by HttpListener. + // I.e. if we don't have an absolute path (must start with '/') and we don't have + // an absolute Uri (must start with http:// or https://), then 'uriString' must be '*'. + Debug.Assert((uriString.Length == 1) && (uriString[0] == '*'), "Unknown request Uri string format", + "Request Uri string is not an absolute Uri, absolute path, or '*': {0}", uriString); + + // Should we ever get here, be consistent with 2.0/3.5 behavior: just add an initial + // slash to the string and treat it as a path: + uriString = "/" + uriString; + } + } + + // Find end of path: The path is terminated by + // - the first '?' character + // - the first '#' character: This is never the case here, since http.sys won't accept + // Uris containing fragments. Also, RFC2616 doesn't allow fragments in request Uris. + // - end of Uri string + int queryIndex = uriString.IndexOf('?'); + if (queryIndex == -1) + { + queryIndex = uriString.Length; + } + + // will always return a != null string. + return AddSlashToAsteriskOnlyPath(uriString.Substring(pathStartIndex, queryIndex - pathStartIndex)); + } + + private static string AddSlashToAsteriskOnlyPath(string path) + { + Debug.Assert(path != null, "'path' must not be null"); + + // If a request like "OPTIONS * HTTP/1.1" is sent to the listener, then the request Uri + // should be "http[s]://server[:port]/*" to be compatible with pre-4.0 behavior. + if ((path.Length == 1) && (path[0] == '*')) + { + return "/*"; + } + + return path; + } + + private enum ParsingResult + { + Success, + InvalidString, + EncodingError + } + + private enum EncodingType + { + Primary, + Secondary + } + } +} diff --git a/SocketHttpListener/Net/HttpResponseStream.Managed.cs b/SocketHttpListener/Net/HttpResponseStream.Managed.cs index 2de3fbb94..1a8a195fb 100644 --- a/SocketHttpListener/Net/HttpResponseStream.Managed.cs +++ b/SocketHttpListener/Net/HttpResponseStream.Managed.cs @@ -9,7 +9,6 @@ using System.Threading; using System.Threading.Tasks; using MediaBrowser.Model.IO; using MediaBrowser.Model.Logging; -using MediaBrowser.Model.Net; using MediaBrowser.Model.System; namespace SocketHttpListener.Net @@ -50,18 +49,18 @@ namespace SocketHttpListener.Net private bool _ignore_errors; private bool _trailer_sent; private Stream _stream; - private readonly IMemoryStreamFactory _memoryStreamFactory; + private readonly IStreamHelper _streamHelper; private readonly Socket _socket; private readonly bool _supportsDirectSocketAccess; private readonly IEnvironmentInfo _environment; private readonly IFileSystem _fileSystem; private readonly ILogger _logger; - internal HttpResponseStream(Stream stream, HttpListenerResponse response, bool ignore_errors, IMemoryStreamFactory memoryStreamFactory, Socket socket, bool supportsDirectSocketAccess, IEnvironmentInfo environment, IFileSystem fileSystem, ILogger logger) + internal HttpResponseStream(Stream stream, HttpListenerResponse response, bool ignore_errors, IStreamHelper streamHelper, Socket socket, bool supportsDirectSocketAccess, IEnvironmentInfo environment, IFileSystem fileSystem, ILogger logger) { _response = response; _ignore_errors = ignore_errors; - _memoryStreamFactory = memoryStreamFactory; + _streamHelper = streamHelper; _socket = socket; _supportsDirectSocketAccess = supportsDirectSocketAccess; _environment = environment; @@ -136,7 +135,7 @@ namespace SocketHttpListener.Net //{ // if (_response.HeadersSent) // return null; - // var ms = _memoryStreamFactory.CreateNew(); + // var ms = CreateNew(); // _response.SendHeaders(closing, ms); // return ms; //} @@ -287,64 +286,9 @@ namespace SocketHttpListener.Net public Task TransmitFile(string path, long offset, long count, FileShareMode fileShareMode, CancellationToken cancellationToken) { - //if (_supportsDirectSocketAccess && offset == 0 && count == 0 && !_response.SendChunked) - //{ - // return TransmitFileOverSocket(path, offset, count, fileShareMode, cancellationToken); - //} - return TransmitFileManaged(path, offset, count, fileShareMode, cancellationToken); } - private readonly byte[] _emptyBuffer = new byte[] { }; - private Task TransmitFileOverSocket(string path, long offset, long count, FileShareMode fileShareMode, CancellationToken cancellationToken) - { - var ms = GetHeaders(false); - - byte[] preBuffer; - if (ms != null) - { - using (var msCopy = new MemoryStream()) - { - ms.CopyTo(msCopy); - preBuffer = msCopy.ToArray(); - } - } - else - { - return TransmitFileManaged(path, offset, count, fileShareMode, cancellationToken); - } - - _stream.Flush(); - - _logger.Info("Socket sending file {0}", path); - - var taskCompletion = new TaskCompletionSource<bool>(); - - Action<IAsyncResult> callback = callbackResult => - { - try - { - _socket.EndSendFile(callbackResult); - taskCompletion.TrySetResult(true); - } - catch (Exception ex) - { - taskCompletion.TrySetException(ex); - } - }; - - var result = _socket.BeginSendFile(path, preBuffer, _emptyBuffer, TransmitFileOptions.UseDefaultWorkerThread, new AsyncCallback(callback), null); - - if (result.CompletedSynchronously) - { - callback(result); - } - - cancellationToken.Register(() => taskCompletion.TrySetCanceled()); - - return taskCompletion.Task; - } - const int StreamCopyToBufferSize = 81920; private async Task TransmitFileManaged(string path, long offset, long count, FileShareMode fileShareMode, CancellationToken cancellationToken) { @@ -375,71 +319,11 @@ namespace SocketHttpListener.Net if (count > 0) { - if (allowAsync) - { - await CopyToInternalAsync(fs, targetStream, count, cancellationToken).ConfigureAwait(false); - } - else - { - await CopyToInternalAsyncWithSyncRead(fs, targetStream, count, cancellationToken).ConfigureAwait(false); - } + await _streamHelper.CopyToAsync(fs, targetStream, count, cancellationToken).ConfigureAwait(false); } else { - if (allowAsync) - { - await fs.CopyToAsync(targetStream, StreamCopyToBufferSize, cancellationToken).ConfigureAwait(false); - } - else - { - fs.CopyTo(targetStream, StreamCopyToBufferSize); - } - } - } - } - - private static async Task CopyToInternalAsyncWithSyncRead(Stream source, Stream destination, long copyLength, CancellationToken cancellationToken) - { - var array = new byte[StreamCopyToBufferSize]; - int bytesRead; - - while ((bytesRead = source.Read(array, 0, array.Length)) != 0) - { - var bytesToWrite = Math.Min(bytesRead, copyLength); - - if (bytesToWrite > 0) - { - await destination.WriteAsync(array, 0, Convert.ToInt32(bytesToWrite), cancellationToken).ConfigureAwait(false); - } - - copyLength -= bytesToWrite; - - if (copyLength <= 0) - { - break; - } - } - } - - private static async Task CopyToInternalAsync(Stream source, Stream destination, long copyLength, CancellationToken cancellationToken) - { - var array = new byte[StreamCopyToBufferSize]; - int bytesRead; - - while ((bytesRead = await source.ReadAsync(array, 0, array.Length, cancellationToken).ConfigureAwait(false)) != 0) - { - var bytesToWrite = Math.Min(bytesRead, copyLength); - - if (bytesToWrite > 0) - { - await destination.WriteAsync(array, 0, Convert.ToInt32(bytesToWrite), cancellationToken).ConfigureAwait(false); - } - - copyLength -= bytesToWrite; - - if (copyLength <= 0) - { - break; + await fs.CopyToAsync(targetStream, StreamCopyToBufferSize, cancellationToken).ConfigureAwait(false); } } } diff --git a/SocketHttpListener/Net/ListenerPrefix.cs b/SocketHttpListener/Net/ListenerPrefix.cs index 605b7b88c..99bb118e5 100644 --- a/SocketHttpListener/Net/ListenerPrefix.cs +++ b/SocketHttpListener/Net/ListenerPrefix.cs @@ -4,50 +4,50 @@ using MediaBrowser.Model.Net; namespace SocketHttpListener.Net { - sealed class ListenerPrefix + internal sealed class ListenerPrefix { - string original; - string host; - ushort port; - string path; - bool secure; - IPAddress[] addresses; - public HttpListener Listener; + private string _original; + private string _host; + private ushort _port; + private string _path; + private bool _secure; + private IPAddress[] _addresses; + internal HttpListener _listener; public ListenerPrefix(string prefix) { - this.original = prefix; + _original = prefix; Parse(prefix); } public override string ToString() { - return original; + return _original; } public IPAddress[] Addresses { - get { return addresses; } - set { addresses = value; } + get { return _addresses; } + set { _addresses = value; } } public bool Secure { - get { return secure; } + get { return _secure; } } public string Host { - get { return host; } + get { return _host; } } public int Port { - get { return (int)port; } + get { return _port; } } public string Path { - get { return path; } + get { return _path; } } // Equals and GetHashCode are required to detect duplicates in HttpListenerPrefixCollection. @@ -57,92 +57,46 @@ namespace SocketHttpListener.Net if (other == null) return false; - return (original == other.original); + return (_original == other._original); } public override int GetHashCode() { - return original.GetHashCode(); + return _original.GetHashCode(); } - void Parse(string uri) + private void Parse(string uri) { ushort default_port = 80; if (uri.StartsWith("https://")) { default_port = 443; - secure = true; + _secure = true; } int length = uri.Length; int start_host = uri.IndexOf(':') + 3; if (start_host >= length) - throw new ArgumentException("No host specified."); + throw new ArgumentException("net_listener_host"); int colon = uri.IndexOf(':', start_host, length - start_host); int root; if (colon > 0) { - host = uri.Substring(start_host, colon - start_host); + _host = uri.Substring(start_host, colon - start_host); root = uri.IndexOf('/', colon, length - colon); - port = (ushort)Int32.Parse(uri.Substring(colon + 1, root - colon - 1)); - path = uri.Substring(root); + _port = (ushort)int.Parse(uri.Substring(colon + 1, root - colon - 1)); + _path = uri.Substring(root); } else { root = uri.IndexOf('/', start_host, length - start_host); - host = uri.Substring(start_host, root - start_host); - port = default_port; - path = uri.Substring(root); + _host = uri.Substring(start_host, root - start_host); + _port = default_port; + _path = uri.Substring(root); } - if (path.Length != 1) - path = path.Substring(0, path.Length - 1); - } - - public static void CheckUri(string uri) - { - if (uri == null) - throw new ArgumentNullException("uriPrefix"); - - if (!uri.StartsWith("http://") && !uri.StartsWith("https://")) - throw new ArgumentException("Only 'http' and 'https' schemes are supported."); - - int length = uri.Length; - int start_host = uri.IndexOf(':') + 3; - if (start_host >= length) - throw new ArgumentException("No host specified."); - - int colon = uri.IndexOf(':', start_host, length - start_host); - if (start_host == colon) - throw new ArgumentException("No host specified."); - - int root; - if (colon > 0) - { - root = uri.IndexOf('/', colon, length - colon); - if (root == -1) - throw new ArgumentException("No path specified."); - - try - { - int p = Int32.Parse(uri.Substring(colon + 1, root - colon - 1)); - if (p <= 0 || p >= 65536) - throw new Exception(); - } - catch - { - throw new ArgumentException("Invalid port."); - } - } - else - { - root = uri.IndexOf('/', start_host, length - start_host); - if (root == -1) - throw new ArgumentException("No path specified."); - } - - if (uri[uri.Length - 1] != '/') - throw new ArgumentException("The prefix must end with '/'"); + if (_path.Length != 1) + _path = _path.Substring(0, _path.Length - 1); } } } diff --git a/SocketHttpListener/Net/SocketAcceptor.cs b/SocketHttpListener/Net/SocketAcceptor.cs deleted file mode 100644 index 36332f52b..000000000 --- a/SocketHttpListener/Net/SocketAcceptor.cs +++ /dev/null @@ -1,124 +0,0 @@ -using System; -using System.Net.Sockets; -using MediaBrowser.Model.Logging; - -namespace SocketHttpListener.Net -{ - public class SocketAcceptor - { - private readonly ILogger _logger; - private readonly Socket _originalSocket; - private readonly Func<bool> _isClosed; - private readonly Action<Socket> _onAccept; - - public SocketAcceptor(ILogger logger, Socket originalSocket, Action<Socket> onAccept, Func<bool> isClosed) - { - if (logger == null) - { - throw new ArgumentNullException("logger"); - } - if (originalSocket == null) - { - throw new ArgumentNullException("originalSocket"); - } - if (onAccept == null) - { - throw new ArgumentNullException("onAccept"); - } - if (isClosed == null) - { - throw new ArgumentNullException("isClosed"); - } - - _logger = logger; - _originalSocket = originalSocket; - _isClosed = isClosed; - _onAccept = onAccept; - } - - public void StartAccept() - { - Socket dummy = null; - StartAccept(null, ref dummy); - } - - public void StartAccept(SocketAsyncEventArgs acceptEventArg, ref Socket accepted) - { - if (acceptEventArg == null) - { - acceptEventArg = new SocketAsyncEventArgs(); - acceptEventArg.Completed += new EventHandler<SocketAsyncEventArgs>(AcceptEventArg_Completed); - } - else - { - // acceptSocket must be cleared since the context object is being reused - acceptEventArg.AcceptSocket = null; - } - - try - { - bool willRaiseEvent = _originalSocket.AcceptAsync(acceptEventArg); - - if (!willRaiseEvent) - { - ProcessAccept(acceptEventArg); - } - } - catch (Exception ex) - { - if (accepted != null) - { - try - { -#if NET46 - accepted.Close(); -#else - accepted.Dispose(); -#endif - } - catch - { - } - accepted = null; - } - } - } - - // This method is the callback method associated with Socket.AcceptAsync - // operations and is invoked when an accept operation is complete - // - void AcceptEventArg_Completed(object sender, SocketAsyncEventArgs e) - { - ProcessAccept(e); - } - - private void ProcessAccept(SocketAsyncEventArgs e) - { - if (_isClosed()) - { - return; - } - - // http://msdn.microsoft.com/en-us/library/system.net.sockets.acceptSocket.acceptasync%28v=vs.110%29.aspx - // Under certain conditions ConnectionReset can occur - // Need to attept to re-accept - if (e.SocketError == SocketError.ConnectionReset) - { - _logger.Error("SocketError.ConnectionReset reported. Attempting to re-accept."); - Socket dummy = null; - StartAccept(e, ref dummy); - return; - } - - var acceptSocket = e.AcceptSocket; - if (acceptSocket != null) - { - //ProcessAccept(acceptSocket); - _onAccept(acceptSocket); - } - - // Accept the next connection request - StartAccept(e, ref acceptSocket); - } - } -} diff --git a/SocketHttpListener/Net/UriScheme.cs b/SocketHttpListener/Net/UriScheme.cs index 35b01e0e5..732fc0e7d 100644 --- a/SocketHttpListener/Net/UriScheme.cs +++ b/SocketHttpListener/Net/UriScheme.cs @@ -1,11 +1,10 @@ using System; using System.Collections.Generic; using System.Text; -using System.Threading.Tasks; namespace SocketHttpListener.Net { - internal class UriScheme + internal static class UriScheme { public const string File = "file"; public const string Ftp = "ftp"; diff --git a/SocketHttpListener/Net/WebHeaderCollection.cs b/SocketHttpListener/Net/WebHeaderCollection.cs index d82dc6816..4bed81404 100644 --- a/SocketHttpListener/Net/WebHeaderCollection.cs +++ b/SocketHttpListener/Net/WebHeaderCollection.cs @@ -35,8 +35,6 @@ namespace SocketHttpListener.Net }; static readonly Dictionary<string, HeaderInfo> headers; - HeaderInfo? headerRestriction; - HeaderInfo? headerConsistency; static WebHeaderCollection() { @@ -108,7 +106,6 @@ namespace SocketHttpListener.Net if (name == null) throw new ArgumentNullException("name"); - ThrowIfRestricted(name); this.AddWithoutValidate(name, value); } @@ -237,7 +234,6 @@ namespace SocketHttpListener.Net if (!IsHeaderValue(value)) throw new ArgumentException("invalid header value"); - ThrowIfRestricted(name); base.Set(name, value); } @@ -317,27 +313,6 @@ namespace SocketHttpListener.Net } } - // Private Methods - - public override int Remove(string name) - { - ThrowIfRestricted(name); - return base.Remove(name); - } - - protected void ThrowIfRestricted(string headerName) - { - if (!headerRestriction.HasValue) - return; - - HeaderInfo info; - if (!headers.TryGetValue(headerName, out info)) - return; - - if ((info & headerRestriction.Value) != 0) - throw new ArgumentException("This header must be modified with the appropriate property."); - } - internal static bool IsMultiValue(string headerName) { if (headerName == null) diff --git a/SocketHttpListener/Net/WebHeaderEncoding.cs b/SocketHttpListener/Net/WebHeaderEncoding.cs index 4a080179e..7290bfc63 100644 --- a/SocketHttpListener/Net/WebHeaderEncoding.cs +++ b/SocketHttpListener/Net/WebHeaderEncoding.cs @@ -83,48 +83,5 @@ namespace SocketHttpListener.Net } return bytes; } - - // The normal client header parser just casts bytes to chars (see GetString). - // Check if those bytes were actually utf-8 instead of ASCII. - // If not, just return the input value. - internal static string DecodeUtf8FromString(string input) - { - if (string.IsNullOrWhiteSpace(input)) - { - return input; - } - - bool possibleUtf8 = false; - for (int i = 0; i < input.Length; i++) - { - if (input[i] > (char)255) - { - return input; // This couldn't have come from the wire, someone assigned it directly. - } - else if (input[i] > (char)127) - { - possibleUtf8 = true; - break; - } - } - if (possibleUtf8) - { - byte[] rawBytes = new byte[input.Length]; - for (int i = 0; i < input.Length; i++) - { - if (input[i] > (char)255) - { - return input; // This couldn't have come from the wire, someone assigned it directly. - } - rawBytes[i] = (byte)input[i]; - } - try - { - return s_utf8Decoder.GetString(rawBytes); - } - catch (ArgumentException) { } // Not actually Utf-8 - } - return input; - } } } diff --git a/SocketHttpListener/Net/WebSockets/HttpListenerWebSocketContext.cs b/SocketHttpListener/Net/WebSockets/HttpListenerWebSocketContext.cs index 803c67b83..49375678d 100644 --- a/SocketHttpListener/Net/WebSockets/HttpListenerWebSocketContext.cs +++ b/SocketHttpListener/Net/WebSockets/HttpListenerWebSocketContext.cs @@ -12,337 +12,87 @@ using SocketHttpListener.Primitives; namespace SocketHttpListener.Net.WebSockets { - /// <summary> - /// Provides the properties used to access the information in a WebSocket connection request - /// received by the <see cref="HttpListener"/>. - /// </summary> - /// <remarks> - /// </remarks> public class HttpListenerWebSocketContext : WebSocketContext { - #region Private Fields + private readonly Uri _requestUri; + private readonly QueryParamCollection _headers; + private readonly CookieCollection _cookieCollection; + private readonly IPrincipal _user; + private readonly bool _isAuthenticated; + private readonly bool _isLocal; + private readonly bool _isSecureConnection; - private HttpListenerContext _context; - private WebSocket _websocket; + private readonly string _origin; + private readonly IEnumerable<string> _secWebSocketProtocols; + private readonly string _secWebSocketVersion; + private readonly string _secWebSocketKey; - #endregion - - #region Internal Constructors + private readonly WebSocket _webSocket; internal HttpListenerWebSocketContext( - HttpListenerContext context, string protocol, ICryptoProvider cryptoProvider, IMemoryStreamFactory memoryStreamFactory) + Uri requestUri, + QueryParamCollection headers, + CookieCollection cookieCollection, + IPrincipal user, + bool isAuthenticated, + bool isLocal, + bool isSecureConnection, + string origin, + IEnumerable<string> secWebSocketProtocols, + string secWebSocketVersion, + string secWebSocketKey, + WebSocket webSocket) { - _context = context; - _websocket = new WebSocket(this, protocol, cryptoProvider, memoryStreamFactory); - } - - #endregion + _cookieCollection = new CookieCollection(); + _cookieCollection.Add(cookieCollection); - #region Internal Properties + //_headers = new NameValueCollection(headers); + _headers = headers; + _user = CopyPrincipal(user); - internal Stream Stream - { - get - { - return _context.Connection.Stream; - } + _requestUri = requestUri; + _isAuthenticated = isAuthenticated; + _isLocal = isLocal; + _isSecureConnection = isSecureConnection; + _origin = origin; + _secWebSocketProtocols = secWebSocketProtocols; + _secWebSocketVersion = secWebSocketVersion; + _secWebSocketKey = secWebSocketKey; + _webSocket = webSocket; } - #endregion + public override Uri RequestUri => _requestUri; - #region Public Properties + public override QueryParamCollection Headers => _headers; - /// <summary> - /// Gets the HTTP cookies included in the request. - /// </summary> - /// <value> - /// A <see cref="System.Net.CookieCollection"/> that contains the cookies. - /// </value> - public override CookieCollection CookieCollection - { - get - { - return _context.Request.Cookies; - } - } + public override string Origin => _origin; - /// <summary> - /// Gets the HTTP headers included in the request. - /// </summary> - /// <value> - /// A <see cref="QueryParamCollection"/> that contains the headers. - /// </value> - public override QueryParamCollection Headers - { - get - { - return _context.Request.Headers; - } - } + public override IEnumerable<string> SecWebSocketProtocols => _secWebSocketProtocols; - /// <summary> - /// Gets the value of the Host header included in the request. - /// </summary> - /// <value> - /// A <see cref="string"/> that represents the value of the Host header. - /// </value> - public override string Host - { - get - { - return _context.Request.Headers["Host"]; - } - } + public override string SecWebSocketVersion => _secWebSocketVersion; - /// <summary> - /// Gets a value indicating whether the client is authenticated. - /// </summary> - /// <value> - /// <c>true</c> if the client is authenticated; otherwise, <c>false</c>. - /// </value> - public override bool IsAuthenticated - { - get - { - return _context.Request.IsAuthenticated; - } - } + public override string SecWebSocketKey => _secWebSocketKey; - /// <summary> - /// Gets a value indicating whether the client connected from the local computer. - /// </summary> - /// <value> - /// <c>true</c> if the client connected from the local computer; otherwise, <c>false</c>. - /// </value> - public override bool IsLocal - { - get - { - return _context.Request.IsLocal; - } - } + public override CookieCollection CookieCollection => _cookieCollection; - /// <summary> - /// Gets a value indicating whether the WebSocket connection is secured. - /// </summary> - /// <value> - /// <c>true</c> if the connection is secured; otherwise, <c>false</c>. - /// </value> - public override bool IsSecureConnection - { - get - { - return _context.Connection.IsSecure; - } - } + public override IPrincipal User => _user; - /// <summary> - /// Gets a value indicating whether the request is a WebSocket connection request. - /// </summary> - /// <value> - /// <c>true</c> if the request is a WebSocket connection request; otherwise, <c>false</c>. - /// </value> - public override bool IsWebSocketRequest - { - get - { - return _context.Request.IsWebSocketRequest; - } - } + public override bool IsAuthenticated => _isAuthenticated; - /// <summary> - /// Gets the value of the Origin header included in the request. - /// </summary> - /// <value> - /// A <see cref="string"/> that represents the value of the Origin header. - /// </value> - public override string Origin - { - get - { - return _context.Request.Headers["Origin"]; - } - } + public override bool IsLocal => _isLocal; - /// <summary> - /// Gets the query string included in the request. - /// </summary> - /// <value> - /// A <see cref="QueryParamCollection"/> that contains the query string parameters. - /// </value> - public override QueryParamCollection QueryString - { - get - { - return _context.Request.QueryString; - } - } + public override bool IsSecureConnection => _isSecureConnection; - /// <summary> - /// Gets the URI requested by the client. - /// </summary> - /// <value> - /// A <see cref="Uri"/> that represents the requested URI. - /// </value> - public override Uri RequestUri - { - get - { - return _context.Request.Url; - } - } + public override WebSocket WebSocket => _webSocket; - /// <summary> - /// Gets the value of the Sec-WebSocket-Key header included in the request. - /// </summary> - /// <remarks> - /// This property provides a part of the information used by the server to prove that it - /// received a valid WebSocket connection request. - /// </remarks> - /// <value> - /// A <see cref="string"/> that represents the value of the Sec-WebSocket-Key header. - /// </value> - public override string SecWebSocketKey + private static IPrincipal CopyPrincipal(IPrincipal user) { - get + if (user != null) { - return _context.Request.Headers["Sec-WebSocket-Key"]; + throw new NotImplementedException(); } - } - /// <summary> - /// Gets the values of the Sec-WebSocket-Protocol header included in the request. - /// </summary> - /// <remarks> - /// This property represents the subprotocols requested by the client. - /// </remarks> - /// <value> - /// An <see cref="T:System.Collections.Generic.IEnumerable{string}"/> instance that provides - /// an enumerator which supports the iteration over the values of the Sec-WebSocket-Protocol - /// header. - /// </value> - public override IEnumerable<string> SecWebSocketProtocols - { - get - { - var protocols = _context.Request.Headers["Sec-WebSocket-Protocol"]; - if (protocols != null) - foreach (var protocol in protocols.Split(',')) - yield return protocol.Trim(); - } - } - - /// <summary> - /// Gets the value of the Sec-WebSocket-Version header included in the request. - /// </summary> - /// <remarks> - /// This property represents the WebSocket protocol version. - /// </remarks> - /// <value> - /// A <see cref="string"/> that represents the value of the Sec-WebSocket-Version header. - /// </value> - public override string SecWebSocketVersion - { - get - { - return _context.Request.Headers["Sec-WebSocket-Version"]; - } - } - - /// <summary> - /// Gets the server endpoint as an IP address and a port number. - /// </summary> - /// <value> - /// </value> - public override IPEndPoint ServerEndPoint - { - get - { - return _context.Connection.LocalEndPoint; - } - } - - /// <summary> - /// Gets the client information (identity, authentication, and security roles). - /// </summary> - /// <value> - /// A <see cref="IPrincipal"/> that represents the client information. - /// </value> - public override IPrincipal User - { - get - { - return _context.User; - } - } - - /// <summary> - /// Gets the client endpoint as an IP address and a port number. - /// </summary> - /// <value> - /// </value> - public override IPEndPoint UserEndPoint - { - get - { - return _context.Connection.RemoteEndPoint; - } + return null; } - - /// <summary> - /// Gets the <see cref="SocketHttpListener.WebSocket"/> instance used for two-way communication - /// between client and server. - /// </summary> - /// <value> - /// A <see cref="SocketHttpListener.WebSocket"/>. - /// </value> - public override WebSocket WebSocket - { - get - { - return _websocket; - } - } - - #endregion - - #region Internal Methods - - internal void Close() - { - try - { - _context.Connection.Close(true); - } - catch - { - // catch errors sending the closing handshake - } - } - - internal void Close(HttpStatusCode code) - { - _context.Response.StatusCode = (int)code; - _context.Response.OutputStream.Dispose(); - } - - #endregion - - #region Public Methods - - /// <summary> - /// Returns a <see cref="string"/> that represents the current - /// <see cref="HttpListenerWebSocketContext"/>. - /// </summary> - /// <returns> - /// A <see cref="string"/> that represents the current - /// <see cref="HttpListenerWebSocketContext"/>. - /// </returns> - public override string ToString() - { - return _context.Request.ToString(); - } - - #endregion } } diff --git a/SocketHttpListener/Net/WebSockets/HttpWebSocket.Managed.cs b/SocketHttpListener/Net/WebSockets/HttpWebSocket.Managed.cs new file mode 100644 index 000000000..571e4bdba --- /dev/null +++ b/SocketHttpListener/Net/WebSockets/HttpWebSocket.Managed.cs @@ -0,0 +1,84 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; + +namespace SocketHttpListener.Net.WebSockets +{ + internal static partial class HttpWebSocket + { + private const string SupportedVersion = "13"; + + internal static async Task<HttpListenerWebSocketContext> AcceptWebSocketAsyncCore(HttpListenerContext context, + string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval, + ArraySegment<byte>? internalBuffer = null) + { + ValidateOptions(subProtocol, receiveBufferSize, MinSendBufferSize, keepAliveInterval); + + // get property will create a new response if one doesn't exist. + HttpListenerResponse response = context.Response; + HttpListenerRequest request = context.Request; + ValidateWebSocketHeaders(context); + + string secWebSocketVersion = request.Headers[HttpKnownHeaderNames.SecWebSocketVersion]; + + // Optional for non-browser client + string origin = request.Headers[HttpKnownHeaderNames.Origin]; + + string[] secWebSocketProtocols = null; + string outgoingSecWebSocketProtocolString; + bool shouldSendSecWebSocketProtocolHeader = + ProcessWebSocketProtocolHeader( + request.Headers[HttpKnownHeaderNames.SecWebSocketProtocol], + subProtocol, + out outgoingSecWebSocketProtocolString); + + if (shouldSendSecWebSocketProtocolHeader) + { + secWebSocketProtocols = new string[] { outgoingSecWebSocketProtocolString }; + response.Headers.Add(HttpKnownHeaderNames.SecWebSocketProtocol, outgoingSecWebSocketProtocolString); + } + + // negotiate the websocket key return value + string secWebSocketKey = request.Headers[HttpKnownHeaderNames.SecWebSocketKey]; + string secWebSocketAccept = HttpWebSocket.GetSecWebSocketAcceptString(secWebSocketKey); + + response.Headers.Add(HttpKnownHeaderNames.Connection, HttpKnownHeaderNames.Upgrade); + response.Headers.Add(HttpKnownHeaderNames.Upgrade, WebSocketUpgradeToken); + response.Headers.Add(HttpKnownHeaderNames.SecWebSocketAccept, secWebSocketAccept); + + response.StatusCode = (int)HttpStatusCode.SwitchingProtocols; // HTTP 101 + response.StatusDescription = HttpStatusDescription.Get(HttpStatusCode.SwitchingProtocols); + + HttpResponseStream responseStream = response.OutputStream as HttpResponseStream; + + // Send websocket handshake headers + await responseStream.WriteWebSocketHandshakeHeadersAsync().ConfigureAwait(false); + + //WebSocket webSocket = WebSocket.CreateFromStream(context.Connection.ConnectedStream, isServer: true, subProtocol, keepAliveInterval); + WebSocket webSocket = new WebSocket(subProtocol); + + HttpListenerWebSocketContext webSocketContext = new HttpListenerWebSocketContext( + request.Url, + request.Headers, + request.Cookies, + context.User, + request.IsAuthenticated, + request.IsLocal, + request.IsSecureConnection, + origin, + secWebSocketProtocols != null ? secWebSocketProtocols : Array.Empty<string>(), + secWebSocketVersion, + secWebSocketKey, + webSocket); + + webSocket.SetContext(webSocketContext, context.Connection.Close, context.Connection.Stream); + + return webSocketContext; + } + + private const bool WebSocketsSupported = true; + } +} diff --git a/SocketHttpListener/Net/WebSockets/HttpWebSocket.cs b/SocketHttpListener/Net/WebSockets/HttpWebSocket.cs new file mode 100644 index 000000000..9dc9143f8 --- /dev/null +++ b/SocketHttpListener/Net/WebSockets/HttpWebSocket.cs @@ -0,0 +1,160 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Diagnostics.CodeAnalysis; +using System.Security.Cryptography; +using System.Threading; + +namespace SocketHttpListener.Net.WebSockets +{ + internal static partial class HttpWebSocket + { + internal const string SecWebSocketKeyGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + internal const string WebSocketUpgradeToken = "websocket"; + internal const int DefaultReceiveBufferSize = 16 * 1024; + internal const int DefaultClientSendBufferSize = 16 * 1024; + + [SuppressMessage("Microsoft.Security", "CA5350", Justification = "SHA1 used only for hashing purposes, not for crypto.")] + internal static string GetSecWebSocketAcceptString(string secWebSocketKey) + { + string retVal; + + // SHA1 used only for hashing purposes, not for crypto. Check here for FIPS compat. + using (SHA1 sha1 = SHA1.Create()) + { + string acceptString = string.Concat(secWebSocketKey, HttpWebSocket.SecWebSocketKeyGuid); + byte[] toHash = Encoding.UTF8.GetBytes(acceptString); + retVal = Convert.ToBase64String(sha1.ComputeHash(toHash)); + } + + return retVal; + } + + // return value here signifies if a Sec-WebSocket-Protocol header should be returned by the server. + internal static bool ProcessWebSocketProtocolHeader(string clientSecWebSocketProtocol, + string subProtocol, + out string acceptProtocol) + { + acceptProtocol = string.Empty; + if (string.IsNullOrEmpty(clientSecWebSocketProtocol)) + { + // client hasn't specified any Sec-WebSocket-Protocol header + if (subProtocol != null) + { + // If the server specified _anything_ this isn't valid. + throw new WebSocketException("UnsupportedProtocol"); + } + // Treat empty and null from the server as the same thing here, server should not send headers. + return false; + } + + // here, we know the client specified something and it's non-empty. + + if (subProtocol == null) + { + // client specified some protocols, server specified 'null'. So server should send headers. + return true; + } + + // here, we know that the client has specified something, it's not empty + // and the server has specified exactly one protocol + + string[] requestProtocols = clientSecWebSocketProtocol.Split(new char[] { ',' }, + StringSplitOptions.RemoveEmptyEntries); + acceptProtocol = subProtocol; + + // client specified protocols, serverOptions has exactly 1 non-empty entry. Check that + // this exists in the list the client specified. + for (int i = 0; i < requestProtocols.Length; i++) + { + string currentRequestProtocol = requestProtocols[i].Trim(); + if (string.Equals(acceptProtocol, currentRequestProtocol, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + } + + throw new WebSocketException("net_WebSockets_AcceptUnsupportedProtocol"); + } + + internal static void ValidateOptions(string subProtocol, int receiveBufferSize, int sendBufferSize, TimeSpan keepAliveInterval) + { + if (subProtocol != null) + { + WebSocketValidate.ValidateSubprotocol(subProtocol); + } + + if (receiveBufferSize < MinReceiveBufferSize) + { + throw new ArgumentOutOfRangeException("net_WebSockets_ArgumentOutOfRange_TooSmall"); + } + + if (sendBufferSize < MinSendBufferSize) + { + throw new ArgumentOutOfRangeException("net_WebSockets_ArgumentOutOfRange_TooSmall"); + } + + if (receiveBufferSize > MaxBufferSize) + { + throw new ArgumentOutOfRangeException("net_WebSockets_ArgumentOutOfRange_TooBig"); + } + + if (sendBufferSize > MaxBufferSize) + { + throw new ArgumentOutOfRangeException("net_WebSockets_ArgumentOutOfRange_TooBig"); + } + + if (keepAliveInterval < Timeout.InfiniteTimeSpan) // -1 millisecond + { + throw new ArgumentOutOfRangeException("net_WebSockets_ArgumentOutOfRange_TooSmall"); + } + } + + internal const int MinSendBufferSize = 16; + internal const int MinReceiveBufferSize = 256; + internal const int MaxBufferSize = 64 * 1024; + + private static void ValidateWebSocketHeaders(HttpListenerContext context) + { + if (!WebSocketsSupported) + { + throw new PlatformNotSupportedException("net_WebSockets_UnsupportedPlatform"); + } + + if (!context.Request.IsWebSocketRequest) + { + throw new WebSocketException("net_WebSockets_AcceptNotAWebSocket"); + } + + string secWebSocketVersion = context.Request.Headers[HttpKnownHeaderNames.SecWebSocketVersion]; + if (string.IsNullOrEmpty(secWebSocketVersion)) + { + throw new WebSocketException("net_WebSockets_AcceptHeaderNotFound"); + } + + if (!string.Equals(secWebSocketVersion, SupportedVersion, StringComparison.OrdinalIgnoreCase)) + { + throw new WebSocketException("net_WebSockets_AcceptUnsupportedWebSocketVersion"); + } + + string secWebSocketKey = context.Request.Headers[HttpKnownHeaderNames.SecWebSocketKey]; + bool isSecWebSocketKeyInvalid = string.IsNullOrWhiteSpace(secWebSocketKey); + if (!isSecWebSocketKeyInvalid) + { + try + { + // key must be 16 bytes then base64-encoded + isSecWebSocketKeyInvalid = Convert.FromBase64String(secWebSocketKey).Length != 16; + } + catch + { + isSecWebSocketKeyInvalid = true; + } + } + if (isSecWebSocketKeyInvalid) + { + throw new WebSocketException("net_WebSockets_AcceptHeaderNotFound"); + } + } + } +} diff --git a/SocketHttpListener/Net/WebSockets/WebSocketCloseStatus.cs b/SocketHttpListener/Net/WebSockets/WebSocketCloseStatus.cs new file mode 100644 index 000000000..0f43b7b80 --- /dev/null +++ b/SocketHttpListener/Net/WebSockets/WebSocketCloseStatus.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace SocketHttpListener.Net.WebSockets +{ + public enum WebSocketCloseStatus + { + NormalClosure = 1000, + EndpointUnavailable = 1001, + ProtocolError = 1002, + InvalidMessageType = 1003, + Empty = 1005, + // AbnormalClosure = 1006, // 1006 is reserved and should never be used by user + InvalidPayloadData = 1007, + PolicyViolation = 1008, + MessageTooBig = 1009, + MandatoryExtension = 1010, + InternalServerError = 1011 + // TLSHandshakeFailed = 1015, // 1015 is reserved and should never be used by user + + // 0 - 999 Status codes in the range 0-999 are not used. + // 1000 - 1999 Status codes in the range 1000-1999 are reserved for definition by this protocol. + // 2000 - 2999 Status codes in the range 2000-2999 are reserved for use by extensions. + // 3000 - 3999 Status codes in the range 3000-3999 MAY be used by libraries and frameworks. The + // interpretation of these codes is undefined by this protocol. End applications MUST + // NOT use status codes in this range. + // 4000 - 4999 Status codes in the range 4000-4999 MAY be used by application code. The interpretation + // of these codes is undefined by this protocol. + } +} diff --git a/SocketHttpListener/Net/WebSockets/WebSocketContext.cs b/SocketHttpListener/Net/WebSockets/WebSocketContext.cs index 9665ab789..071b5fe05 100644 --- a/SocketHttpListener/Net/WebSockets/WebSocketContext.cs +++ b/SocketHttpListener/Net/WebSockets/WebSocketContext.cs @@ -8,176 +8,19 @@ using MediaBrowser.Model.Services; namespace SocketHttpListener.Net.WebSockets { - /// <summary> - /// Exposes the properties used to access the information in a WebSocket connection request. - /// </summary> - /// <remarks> - /// The WebSocketContext class is an abstract class. - /// </remarks> public abstract class WebSocketContext { - #region Protected Constructors - - /// <summary> - /// Initializes a new instance of the <see cref="WebSocketContext"/> class. - /// </summary> - protected WebSocketContext() - { - } - - #endregion - - #region Public Properties - - /// <summary> - /// Gets the HTTP cookies included in the request. - /// </summary> - /// <value> - /// A <see cref="System.Net.CookieCollection"/> that contains the cookies. - /// </value> - public abstract CookieCollection CookieCollection { get; } - - /// <summary> - /// Gets the HTTP headers included in the request. - /// </summary> - /// <value> - /// A <see cref="QueryParamCollection"/> that contains the headers. - /// </value> + public abstract Uri RequestUri { get; } public abstract QueryParamCollection Headers { get; } - - /// <summary> - /// Gets the value of the Host header included in the request. - /// </summary> - /// <value> - /// A <see cref="string"/> that represents the value of the Host header. - /// </value> - public abstract string Host { get; } - - /// <summary> - /// Gets a value indicating whether the client is authenticated. - /// </summary> - /// <value> - /// <c>true</c> if the client is authenticated; otherwise, <c>false</c>. - /// </value> - public abstract bool IsAuthenticated { get; } - - /// <summary> - /// Gets a value indicating whether the client connected from the local computer. - /// </summary> - /// <value> - /// <c>true</c> if the client connected from the local computer; otherwise, <c>false</c>. - /// </value> - public abstract bool IsLocal { get; } - - /// <summary> - /// Gets a value indicating whether the WebSocket connection is secured. - /// </summary> - /// <value> - /// <c>true</c> if the connection is secured; otherwise, <c>false</c>. - /// </value> - public abstract bool IsSecureConnection { get; } - - /// <summary> - /// Gets a value indicating whether the request is a WebSocket connection request. - /// </summary> - /// <value> - /// <c>true</c> if the request is a WebSocket connection request; otherwise, <c>false</c>. - /// </value> - public abstract bool IsWebSocketRequest { get; } - - /// <summary> - /// Gets the value of the Origin header included in the request. - /// </summary> - /// <value> - /// A <see cref="string"/> that represents the value of the Origin header. - /// </value> public abstract string Origin { get; } - - /// <summary> - /// Gets the query string included in the request. - /// </summary> - /// <value> - /// A <see cref="QueryParamCollection"/> that contains the query string parameters. - /// </value> - public abstract QueryParamCollection QueryString { get; } - - /// <summary> - /// Gets the URI requested by the client. - /// </summary> - /// <value> - /// A <see cref="Uri"/> that represents the requested URI. - /// </value> - public abstract Uri RequestUri { get; } - - /// <summary> - /// Gets the value of the Sec-WebSocket-Key header included in the request. - /// </summary> - /// <remarks> - /// This property provides a part of the information used by the server to prove that it - /// received a valid WebSocket connection request. - /// </remarks> - /// <value> - /// A <see cref="string"/> that represents the value of the Sec-WebSocket-Key header. - /// </value> - public abstract string SecWebSocketKey { get; } - - /// <summary> - /// Gets the values of the Sec-WebSocket-Protocol header included in the request. - /// </summary> - /// <remarks> - /// This property represents the subprotocols requested by the client. - /// </remarks> - /// <value> - /// An <see cref="T:System.Collections.Generic.IEnumerable{string}"/> instance that provides - /// an enumerator which supports the iteration over the values of the Sec-WebSocket-Protocol - /// header. - /// </value> public abstract IEnumerable<string> SecWebSocketProtocols { get; } - - /// <summary> - /// Gets the value of the Sec-WebSocket-Version header included in the request. - /// </summary> - /// <remarks> - /// This property represents the WebSocket protocol version. - /// </remarks> - /// <value> - /// A <see cref="string"/> that represents the value of the Sec-WebSocket-Version header. - /// </value> public abstract string SecWebSocketVersion { get; } - - /// <summary> - /// Gets the server endpoint as an IP address and a port number. - /// </summary> - /// <value> - /// A <see cref="System.Net.IPEndPoint"/> that represents the server endpoint. - /// </value> - public abstract IPEndPoint ServerEndPoint { get; } - - /// <summary> - /// Gets the client information (identity, authentication, and security roles). - /// </summary> - /// <value> - /// A <see cref="IPrincipal"/> that represents the client information. - /// </value> + public abstract string SecWebSocketKey { get; } + public abstract CookieCollection CookieCollection { get; } public abstract IPrincipal User { get; } - - /// <summary> - /// Gets the client endpoint as an IP address and a port number. - /// </summary> - /// <value> - /// A <see cref="System.Net.IPEndPoint"/> that represents the client endpoint. - /// </value> - public abstract IPEndPoint UserEndPoint { get; } - - /// <summary> - /// Gets the <see cref="SocketHttpListener.WebSocket"/> instance used for two-way communication - /// between client and server. - /// </summary> - /// <value> - /// A <see cref="SocketHttpListener.WebSocket"/>. - /// </value> + public abstract bool IsAuthenticated { get; } + public abstract bool IsLocal { get; } + public abstract bool IsSecureConnection { get; } public abstract WebSocket WebSocket { get; } - - #endregion } } diff --git a/SocketHttpListener/Net/WebSockets/WebSocketValidate.cs b/SocketHttpListener/Net/WebSockets/WebSocketValidate.cs new file mode 100644 index 000000000..00895ea01 --- /dev/null +++ b/SocketHttpListener/Net/WebSockets/WebSocketValidate.cs @@ -0,0 +1,143 @@ +using System; +using System.Collections.Generic; +using System.Text; +using MediaBrowser.Model.Net; +using System.Globalization; +using WebSocketState = System.Net.WebSockets.WebSocketState; + +namespace SocketHttpListener.Net.WebSockets +{ + internal static partial class WebSocketValidate + { + internal const int MaxControlFramePayloadLength = 123; + private const int CloseStatusCodeAbort = 1006; + private const int CloseStatusCodeFailedTLSHandshake = 1015; + private const int InvalidCloseStatusCodesFrom = 0; + private const int InvalidCloseStatusCodesTo = 999; + private const string Separators = "()<>@,;:\\\"/[]?={} "; + + internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates) + { + string validStatesText = string.Empty; + + if (validStates != null && validStates.Length > 0) + { + foreach (WebSocketState validState in validStates) + { + if (currentState == validState) + { + // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior. + if (isDisposed) + { + throw new ObjectDisposedException(nameof(WebSocket)); + } + + return; + } + } + + validStatesText = string.Join(", ", validStates); + } + + throw new WebSocketException("net_WebSockets_InvalidState"); + } + + internal static void ValidateSubprotocol(string subProtocol) + { + if (string.IsNullOrWhiteSpace(subProtocol)) + { + throw new ArgumentException("net_WebSockets_InvalidEmptySubProtocol"); + } + + string invalidChar = null; + int i = 0; + while (i < subProtocol.Length) + { + char ch = subProtocol[i]; + if (ch < 0x21 || ch > 0x7e) + { + invalidChar = string.Format(CultureInfo.InvariantCulture, "[{0}]", (int)ch); + break; + } + + if (!char.IsLetterOrDigit(ch) && + Separators.IndexOf(ch) >= 0) + { + invalidChar = ch.ToString(); + break; + } + + i++; + } + + if (invalidChar != null) + { + throw new ArgumentException("net_WebSockets_InvalidCharInProtocolString"); + } + } + + internal static void ValidateCloseStatus(WebSocketCloseStatus closeStatus, string statusDescription) + { + if (closeStatus == WebSocketCloseStatus.Empty && !string.IsNullOrEmpty(statusDescription)) + { + throw new ArgumentException("net_WebSockets_ReasonNotNull"); + } + + int closeStatusCode = (int)closeStatus; + + if ((closeStatusCode >= InvalidCloseStatusCodesFrom && + closeStatusCode <= InvalidCloseStatusCodesTo) || + closeStatusCode == CloseStatusCodeAbort || + closeStatusCode == CloseStatusCodeFailedTLSHandshake) + { + // CloseStatus 1006 means Aborted - this will never appear on the wire and is reflected by calling WebSocket.Abort + throw new ArgumentException("net_WebSockets_InvalidCloseStatusCode"); + } + + int length = 0; + if (!string.IsNullOrEmpty(statusDescription)) + { + length = Encoding.UTF8.GetByteCount(statusDescription); + } + + if (length > MaxControlFramePayloadLength) + { + throw new ArgumentException("net_WebSockets_InvalidCloseStatusDescription"); + } + } + + internal static void ValidateArraySegment(ArraySegment<byte> arraySegment, string parameterName) + { + if (arraySegment.Array == null) + { + throw new ArgumentNullException(parameterName + "." + nameof(arraySegment.Array)); + } + if (arraySegment.Offset < 0 || arraySegment.Offset > arraySegment.Array.Length) + { + throw new ArgumentOutOfRangeException(parameterName + "." + nameof(arraySegment.Offset)); + } + if (arraySegment.Count < 0 || arraySegment.Count > (arraySegment.Array.Length - arraySegment.Offset)) + { + throw new ArgumentOutOfRangeException(parameterName + "." + nameof(arraySegment.Count)); + } + } + + internal static void ValidateBuffer(byte[] buffer, int offset, int count) + { + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } + + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException(nameof(offset)); + } + + if (count < 0 || count > (buffer.Length - offset)) + { + throw new ArgumentOutOfRangeException(nameof(count)); + } + } + } +} diff --git a/SocketHttpListener/Primitives/ITextEncoding.cs b/SocketHttpListener/Primitives/ITextEncoding.cs index 2c25a308c..a256a077d 100644 --- a/SocketHttpListener/Primitives/ITextEncoding.cs +++ b/SocketHttpListener/Primitives/ITextEncoding.cs @@ -12,5 +12,10 @@ namespace SocketHttpListener.Primitives { return Encoding.UTF8; } + + public static Encoding GetDefaultEncoding() + { + return Encoding.UTF8; + } } } diff --git a/SocketHttpListener/SocketHttpListener.csproj b/SocketHttpListener/SocketHttpListener.csproj index 6ed42ea88..da80fa94a 100644 --- a/SocketHttpListener/SocketHttpListener.csproj +++ b/SocketHttpListener/SocketHttpListener.csproj @@ -1,120 +1,18 @@ -<?xml version="1.0" encoding="utf-8"?> -<Project ToolsVersion="14.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> - <Import Project="$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props" Condition="Exists('$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props')" /> - <PropertyGroup> - <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> - <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform> - <ProjectGuid>{1D74413B-E7CF-455B-B021-F52BDF881542}</ProjectGuid> - <OutputType>Library</OutputType> - <AppDesignerFolder>Properties</AppDesignerFolder> - <RootNamespace>SocketHttpListener</RootNamespace> - <AssemblyName>SocketHttpListener</AssemblyName> - <TargetFrameworkVersion>v4.6</TargetFrameworkVersion> - <FileAlignment>512</FileAlignment> - <TargetFrameworkProfile /> - </PropertyGroup> - <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' "> - <DebugSymbols>true</DebugSymbols> - <DebugType>full</DebugType> - <Optimize>false</Optimize> - <OutputPath>bin\Debug\</OutputPath> - <DefineConstants>DEBUG;TRACE</DefineConstants> - <ErrorReport>prompt</ErrorReport> - <WarningLevel>4</WarningLevel> - <AllowUnsafeBlocks>true</AllowUnsafeBlocks> - </PropertyGroup> - <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' "> - <DebugType>pdbonly</DebugType> - <Optimize>true</Optimize> - <OutputPath>bin\Release\</OutputPath> - <DefineConstants>TRACE</DefineConstants> - <ErrorReport>prompt</ErrorReport> - <WarningLevel>4</WarningLevel> - <AllowUnsafeBlocks>true</AllowUnsafeBlocks> - </PropertyGroup> +<Project Sdk="Microsoft.NET.Sdk"> + <ItemGroup> - <Reference Include="System" /> - <Reference Include="System.Core" /> - <Reference Include="System.Xml.Linq" /> - <Reference Include="System.Data.DataSetExtensions" /> - <Reference Include="Microsoft.CSharp" /> - <Reference Include="System.Data" /> - <Reference Include="System.Net.Http" /> - <Reference Include="System.Xml" /> + <ProjectReference Include="..\MediaBrowser.Common\MediaBrowser.Common.csproj" /> + <ProjectReference Include="..\MediaBrowser.Model\MediaBrowser.Model.csproj" /> </ItemGroup> + <ItemGroup> - <Compile Include="..\SharedVersion.cs"> - <Link>Properties\SharedVersion.cs</Link> - </Compile> - <Compile Include="ByteOrder.cs" /> - <Compile Include="CloseEventArgs.cs" /> - <Compile Include="CloseStatusCode.cs" /> - <Compile Include="CompressionMethod.cs" /> - <Compile Include="ErrorEventArgs.cs" /> - <Compile Include="Ext.cs" /> - <Compile Include="Fin.cs" /> - <Compile Include="HttpBase.cs" /> - <Compile Include="HttpResponse.cs" /> - <Compile Include="Mask.cs" /> - <Compile Include="MessageEventArgs.cs" /> - <Compile Include="Net\AuthenticationSchemeSelector.cs" /> - <Compile Include="Net\BoundaryType.cs" /> - <Compile Include="Net\ChunkedInputStream.cs" /> - <Compile Include="Net\ChunkStream.cs" /> - <Compile Include="Net\CookieHelper.cs" /> - <Compile Include="Net\EndPointListener.cs" /> - <Compile Include="Net\EndPointManager.cs" /> - <Compile Include="Net\EntitySendFormat.cs" /> - <Compile Include="Net\HttpConnection.cs" /> - <Compile Include="Net\HttpListener.cs" /> - <Compile Include="Net\HttpListenerBasicIdentity.cs" /> - <Compile Include="Net\HttpListenerContext.cs" /> - <Compile Include="Net\HttpListenerPrefixCollection.cs" /> - <Compile Include="Net\HttpListenerRequest.cs" /> - <Compile Include="Net\HttpListenerResponse.Managed.cs" /> - <Compile Include="Net\HttpListenerResponse.cs" /> - <Compile Include="Net\HttpRequestStream.cs" /> - <Compile Include="Net\HttpRequestStream.Managed.cs" /> - <Compile Include="Net\HttpResponseStream.cs" /> - <Compile Include="Net\HttpResponseStream.Managed.cs" /> - <Compile Include="Net\HttpStatusCode.cs" /> - <Compile Include="Net\HttpStatusDescription.cs" /> - <Compile Include="Net\HttpStreamAsyncResult.cs" /> - <Compile Include="Net\HttpVersion.cs" /> - <Compile Include="Net\ListenerPrefix.cs" /> - <Compile Include="Net\SocketAcceptor.cs" /> - <Compile Include="Net\UriScheme.cs" /> - <Compile Include="Net\WebHeaderCollection.cs" /> - <Compile Include="Net\WebHeaderEncoding.cs" /> - <Compile Include="Net\WebSockets\HttpListenerWebSocketContext.cs" /> - <Compile Include="Net\WebSockets\WebSocketContext.cs" /> - <Compile Include="Opcode.cs" /> - <Compile Include="PayloadData.cs" /> - <Compile Include="Primitives\ITextEncoding.cs" /> - <Compile Include="Properties\AssemblyInfo.cs" /> - <Compile Include="Rsv.cs" /> - <Compile Include="SocketStream.cs" /> - <Compile Include="WebSocket.cs" /> - <Compile Include="WebSocketException.cs" /> - <Compile Include="WebSocketFrame.cs" /> - <Compile Include="WebSocketState.cs" /> + <Compile Include="..\SharedVersion.cs"/> </ItemGroup> - <ItemGroup> - <ProjectReference Include="..\MediaBrowser.Common\MediaBrowser.Common.csproj"> - <Project>{9142eefa-7570-41e1-bfcc-468bb571af2f}</Project> - <Name>MediaBrowser.Common</Name> - </ProjectReference> - <ProjectReference Include="..\MediaBrowser.Model\MediaBrowser.Model.csproj"> - <Project>{7eeeb4bb-f3e8-48fc-b4c5-70f0fff8329b}</Project> - <Name>MediaBrowser.Model</Name> - </ProjectReference> - </ItemGroup> - <Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" /> - <!-- To modify your build process, add your task inside one of the targets below and uncomment it. - Other similar extension points exist, see Microsoft.Common.targets. - <Target Name="BeforeBuild"> - </Target> - <Target Name="AfterBuild"> - </Target> - --> -</Project>
\ No newline at end of file + + <PropertyGroup> + <TargetFramework>netcoreapp2.1</TargetFramework> + <AllowUnsafeBlocks>true</AllowUnsafeBlocks> + <GenerateAssemblyInfo>false</GenerateAssemblyInfo> + </PropertyGroup> + +</Project> diff --git a/SocketHttpListener/WebSocket.cs b/SocketHttpListener/WebSocket.cs index 57c075e32..385b25aed 100644 --- a/SocketHttpListener/WebSocket.cs +++ b/SocketHttpListener/WebSocket.cs @@ -11,6 +11,8 @@ using MediaBrowser.Model.IO; using SocketHttpListener.Net.WebSockets; using SocketHttpListener.Primitives; using HttpStatusCode = SocketHttpListener.Net.HttpStatusCode; +using System.Net.Sockets; +using WebSocketState = System.Net.WebSockets.WebSocketState; namespace SocketHttpListener { @@ -30,7 +32,6 @@ namespace SocketHttpListener private CompressionMethod _compression; private WebSocketContext _context; private CookieCollection _cookies; - private string _extensions; private AutoResetEvent _exitReceiving; private object _forConn; private object _forEvent; @@ -52,9 +53,6 @@ namespace SocketHttpListener private Stream _stream; private Uri _uri; private const string _version = "13"; - private readonly IMemoryStreamFactory _memoryStreamFactory; - - private readonly ICryptoProvider _cryptoProvider; #endregion @@ -67,43 +65,29 @@ namespace SocketHttpListener #region Internal Constructors // As server - internal WebSocket(HttpListenerWebSocketContext context, string protocol, ICryptoProvider cryptoProvider, IMemoryStreamFactory memoryStreamFactory) + internal WebSocket(string protocol) { - _context = context; _protocol = protocol; - _cryptoProvider = cryptoProvider; - _memoryStreamFactory = memoryStreamFactory; + } + + public void SetContext(HttpListenerWebSocketContext context, Action closeContextFn, Stream stream) + { + _context = context; - _closeContext = context.Close; + _closeContext = closeContextFn; _secure = context.IsSecureConnection; - _stream = context.Stream; + _stream = stream; init(); } - #endregion - - // As server - internal Func<WebSocketContext, string> CustomHandshakeRequestChecker + public static TimeSpan DefaultKeepAliveInterval { - get - { - return _handshakeRequestChecker ?? (context => null); - } - - set - { - _handshakeRequestChecker = value; - } + // In the .NET Framework, this pulls the value from a P/Invoke. Here we just hardcode it to a reasonable default. + get { return TimeSpan.FromSeconds(30); } } - internal bool IsConnected - { - get - { - return _readyState == WebSocketState.Open || _readyState == WebSocketState.Closing; - } - } + #endregion /// <summary> /// Gets the state of the WebSocket connection. @@ -146,44 +130,6 @@ namespace SocketHttpListener #region Private Methods - // As server - private bool acceptHandshake() - { - var msg = checkIfValidHandshakeRequest(_context); - if (msg != null) - { - error("An error has occurred while connecting: " + msg); - Close(HttpStatusCode.BadRequest); - - return false; - } - - if (_protocol != null && - !_context.SecWebSocketProtocols.Contains(protocol => protocol == _protocol)) - _protocol = null; - - ////var extensions = _context.Headers["Sec-WebSocket-Extensions"]; - ////if (extensions != null && extensions.Length > 0) - //// processSecWebSocketExtensionsHeader(extensions); - - return sendHttpResponse(createHandshakeResponse()); - } - - // As server - private string checkIfValidHandshakeRequest(WebSocketContext context) - { - var headers = context.Headers; - return context.RequestUri == null - ? "Invalid request url." - : !context.IsWebSocketRequest - ? "Not WebSocket connection request." - : !validateSecWebSocketKeyHeader(headers["Sec-WebSocket-Key"]) - ? "Invalid Sec-WebSocket-Key header." - : !validateSecWebSocketVersionClientHeader(headers["Sec-WebSocket-Version"]) - ? "Invalid Sec-WebSocket-Version header." - : CustomHandshakeRequestChecker(context); - } - private void close(CloseStatusCode code, string reason, bool wait) { close(new PayloadData(((ushort)code).Append(reason)), !code.IsReserved(), wait); @@ -193,20 +139,19 @@ namespace SocketHttpListener { lock (_forConn) { - if (_readyState == WebSocketState.Closing || _readyState == WebSocketState.Closed) + if (_readyState == WebSocketState.CloseSent || _readyState == WebSocketState.Closed) { return; } - _readyState = WebSocketState.Closing; + _readyState = WebSocketState.CloseSent; } var e = new CloseEventArgs(payload); e.WasClean = closeHandshake( send ? WebSocketFrame.CreateCloseFrame(Mask.Unmask, payload).ToByteArray() : null, - wait ? 1000 : 0, - closeServerResources); + wait ? 1000 : 0); _readyState = WebSocketState.Closed; try @@ -219,14 +164,15 @@ namespace SocketHttpListener } } - private bool closeHandshake(byte[] frameAsBytes, int millisecondsTimeout, Action release) + private bool closeHandshake(byte[] frameAsBytes, int millisecondsTimeout) { var sent = frameAsBytes != null && writeBytes(frameAsBytes); var received = millisecondsTimeout == 0 || (sent && _exitReceiving != null && _exitReceiving.WaitOne(millisecondsTimeout)); - release(); + closeServerResources(); + if (_receivePong != null) { _receivePong.Dispose(); @@ -250,7 +196,15 @@ namespace SocketHttpListener if (_closeContext == null) return; - _closeContext(); + try + { + _closeContext(); + } + catch (SocketException) + { + // it could be unable to send the handshake response + } + _closeContext = null; _stream = null; _context = null; @@ -321,26 +275,6 @@ namespace SocketHttpListener return res; } - // As server - private HttpResponse createHandshakeResponse() - { - var res = HttpResponse.CreateWebSocketResponse(); - - var headers = res.Headers; - headers["Sec-WebSocket-Accept"] = CreateResponseKey(_base64Key); - - if (_protocol != null) - headers["Sec-WebSocket-Protocol"] = _protocol; - - if (_extensions != null) - headers["Sec-WebSocket-Extensions"] = _extensions; - - if (_cookies.Count > 0) - res.SetCookies(_cookies); - - return res; - } - private MessageEventArgs dequeueFromMessageEventQueue() { lock (_forMessageEventQueue) @@ -403,7 +337,10 @@ namespace SocketHttpListener { try { - OnOpen.Emit(this, EventArgs.Empty); + if (OnOpen != null) + { + OnOpen(this, EventArgs.Empty); + } } catch (Exception ex) { @@ -463,7 +400,7 @@ namespace SocketHttpListener private bool processFragments(WebSocketFrame first) { - using (var buff = _memoryStreamFactory.CreateNew()) + using (var buff = new MemoryStream()) { buff.WriteBytes(first.PayloadData.ApplicationData); if (!concatenateFragmentsInto(buff)) @@ -691,23 +628,6 @@ namespace SocketHttpListener receive(); } - // As server - private bool validateSecWebSocketKeyHeader(string value) - { - if (value == null || value.Length == 0) - return false; - - _base64Key = value; - return true; - } - - // As server - private bool validateSecWebSocketVersionClientHeader(string value) - { - return true; - //return value != null && value == _version; - } - private bool writeBytes(byte[] data) { try @@ -715,7 +635,7 @@ namespace SocketHttpListener _stream.Write(data, 0, data.Length); return true; } - catch (Exception ex) + catch (Exception) { return false; } @@ -728,9 +648,9 @@ namespace SocketHttpListener // As server internal void Close(HttpResponse response) { - _readyState = WebSocketState.Closing; - + _readyState = WebSocketState.CloseSent; sendHttpResponse(response); + closeServerResources(); _readyState = WebSocketState.Closed; @@ -747,11 +667,8 @@ namespace SocketHttpListener { try { - if (acceptHandshake()) - { - _readyState = WebSocketState.Open; - open(); - } + _readyState = WebSocketState.Open; + open(); } catch (Exception ex) { @@ -759,15 +676,6 @@ namespace SocketHttpListener } } - private string CreateResponseKey(string base64Key) - { - var buff = new StringBuilder(base64Key, 64); - buff.Append(_guid); - var src = _cryptoProvider.ComputeSHA1(Encoding.UTF8.GetBytes(buff.ToString())); - - return Convert.ToBase64String(src); - } - #endregion #region Public Methods @@ -830,18 +738,20 @@ namespace SocketHttpListener /// <param name="data"> /// An array of <see cref="byte"/> that represents the binary data to send. /// </param> - /// An Action<bool> delegate that references the method(s) called when the send is - /// complete. A <see cref="bool"/> passed to this delegate is <c>true</c> if the send is - /// complete successfully; otherwise, <c>false</c>. public Task SendAsync(byte[] data) { - var msg = _readyState.CheckIfOpen() ?? data.CheckIfValidSendData(); + if (data == null) + { + throw new ArgumentNullException("data"); + } + + var msg = _readyState.CheckIfOpen(); if (msg != null) { throw new Exception(msg); } - return sendAsync(Opcode.Binary, _memoryStreamFactory.CreateNew(data)); + return sendAsync(Opcode.Binary, new MemoryStream(data)); } /// <summary> @@ -853,18 +763,20 @@ namespace SocketHttpListener /// <param name="data"> /// A <see cref="string"/> that represents the text data to send. /// </param> - /// An Action<bool> delegate that references the method(s) called when the send is - /// complete. A <see cref="bool"/> passed to this delegate is <c>true</c> if the send is - /// complete successfully; otherwise, <c>false</c>. public Task SendAsync(string data) { - var msg = _readyState.CheckIfOpen() ?? data.CheckIfValidSendData(); + if (data == null) + { + throw new ArgumentNullException("data"); + } + + var msg = _readyState.CheckIfOpen(); if (msg != null) { throw new Exception(msg); } - return sendAsync(Opcode.Text, _memoryStreamFactory.CreateNew(Encoding.UTF8.GetBytes(data))); + return sendAsync(Opcode.Text, new MemoryStream(Encoding.UTF8.GetBytes(data))); } #endregion @@ -880,7 +792,6 @@ namespace SocketHttpListener void IDisposable.Dispose() { Close(CloseStatusCode.Away, null); - GC.SuppressFinalize(this); } #endregion diff --git a/SocketHttpListener/WebSocketState.cs b/SocketHttpListener/WebSocketState.cs deleted file mode 100644 index 73b3a49dd..000000000 --- a/SocketHttpListener/WebSocketState.cs +++ /dev/null @@ -1,35 +0,0 @@ -namespace SocketHttpListener -{ - /// <summary> - /// Contains the values of the state of the WebSocket connection. - /// </summary> - /// <remarks> - /// The values of the state are defined in - /// <see href="http://www.w3.org/TR/websockets/#dom-websocket-readystate">The WebSocket - /// API</see>. - /// </remarks> - public enum WebSocketState : ushort - { - /// <summary> - /// Equivalent to numeric value 0. - /// Indicates that the connection has not yet been established. - /// </summary> - Connecting = 0, - /// <summary> - /// Equivalent to numeric value 1. - /// Indicates that the connection is established and the communication is possible. - /// </summary> - Open = 1, - /// <summary> - /// Equivalent to numeric value 2. - /// Indicates that the connection is going through the closing handshake or - /// the <c>WebSocket.Close</c> method has been invoked. - /// </summary> - Closing = 2, - /// <summary> - /// Equivalent to numeric value 3. - /// Indicates that the connection has been closed or couldn't be opened. - /// </summary> - Closed = 3 - } -} |
