diff options
Diffstat (limited to 'SocketHttpListener/Net/HttpConnection.cs')
| -rw-r--r-- | SocketHttpListener/Net/HttpConnection.cs | 532 |
1 files changed, 532 insertions, 0 deletions
diff --git a/SocketHttpListener/Net/HttpConnection.cs b/SocketHttpListener/Net/HttpConnection.cs new file mode 100644 index 000000000..5beea5f22 --- /dev/null +++ b/SocketHttpListener/Net/HttpConnection.cs @@ -0,0 +1,532 @@ +using System; +using System.IO; +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using MediaBrowser.Model.Cryptography; +using MediaBrowser.Model.IO; +using MediaBrowser.Model.System; +using Microsoft.Extensions.Logging; +namespace SocketHttpListener.Net +{ + sealed class HttpConnection + { + private static AsyncCallback s_onreadCallback = new AsyncCallback(OnRead); + const int BufferSize = 8192; + Socket _socket; + Stream _stream; + HttpEndPointListener _epl; + MemoryStream _memoryStream; + byte[] _buffer; + HttpListenerContext _context; + StringBuilder _currentLine; + ListenerPrefix _prefix; + HttpRequestStream _requestStream; + HttpResponseStream _responseStream; + bool _chunked; + int _reuses; + bool _contextBound; + bool secure; + IPEndPoint local_ep; + HttpListener _lastListener; + X509Certificate cert; + SslStream ssl_stream; + + private readonly ILogger _logger; + private readonly ICryptoProvider _cryptoProvider; + private readonly IStreamHelper _streamHelper; + private readonly IFileSystem _fileSystem; + private readonly IEnvironmentInfo _environment; + + public HttpConnection(ILogger logger, Socket socket, HttpEndPointListener epl, bool secure, + X509Certificate cert, ICryptoProvider cryptoProvider, IStreamHelper streamHelper, IFileSystem fileSystem, + IEnvironmentInfo environment) + { + _logger = logger; + this._socket = socket; + this._epl = epl; + this.secure = secure; + this.cert = cert; + _cryptoProvider = cryptoProvider; + _streamHelper = streamHelper; + _fileSystem = fileSystem; + _environment = environment; + + if (secure == false) + { + _stream = new SocketStream(_socket, false); + } + else + { + ssl_stream = new SslStream(new SocketStream(_socket, false), false, (t, c, ch, e) => + { + if (c == null) + { + return true; + } + + //var c2 = c as X509Certificate2; + //if (c2 == null) + //{ + // c2 = new X509Certificate2(c.GetRawCertData()); + //} + + //_clientCert = c2; + //_clientCertErrors = new int[] { (int)e }; + return true; + }); + + _stream = ssl_stream; + } + } + + public Stream Stream => _stream; + + public async Task Init() + { + if (ssl_stream != null) + { + var enableAsync = true; + if (enableAsync) + { + await ssl_stream.AuthenticateAsServerAsync(cert, false, (SslProtocols)ServicePointManager.SecurityProtocol, false).ConfigureAwait(false); + } + else + { + ssl_stream.AuthenticateAsServer(cert, false, (SslProtocols)ServicePointManager.SecurityProtocol, false); + } + } + + InitInternal(); + } + + private void InitInternal() + { + _contextBound = false; + _requestStream = null; + _responseStream = null; + _prefix = null; + _chunked = false; + _memoryStream = new MemoryStream(); + _position = 0; + _inputState = InputState.RequestLine; + _lineState = LineState.None; + _context = new HttpListenerContext(this); + } + + public bool IsClosed => (_socket == null); + + public int Reuses => _reuses; + + public IPEndPoint LocalEndPoint + { + get + { + if (local_ep != null) + return local_ep; + + local_ep = (IPEndPoint)_socket.LocalEndPoint; + return local_ep; + } + } + + public IPEndPoint RemoteEndPoint => _socket.RemoteEndPoint as IPEndPoint; + + public bool IsSecure => secure; + + public ListenerPrefix Prefix + { + get => _prefix; + set => _prefix = value; + } + + private void OnTimeout(object unused) + { + //_logger.LogInformation("HttpConnection timer fired"); + CloseSocket(); + Unbind(); + } + + public void BeginReadRequest() + { + if (_buffer == null) + { + _buffer = new byte[BufferSize]; + } + + try + { + _stream.BeginRead(_buffer, 0, BufferSize, s_onreadCallback, this); + } + catch + { + CloseSocket(); + Unbind(); + } + } + + public HttpRequestStream GetRequestStream(bool chunked, long contentlength) + { + if (_requestStream == null) + { + byte[] buffer = _memoryStream.GetBuffer(); + int length = (int)_memoryStream.Length; + _memoryStream = null; + if (chunked) + { + _chunked = true; + //_context.Response.SendChunked = true; + _requestStream = new ChunkedInputStream(_context, _stream, buffer, _position, length - _position); + } + else + { + _requestStream = new HttpRequestStream(_stream, buffer, _position, length - _position, contentlength); + } + } + return _requestStream; + } + + public HttpResponseStream GetResponseStream(bool isExpect100Continue = false) + { + // TODO: can we get this _stream before reading the input? + if (_responseStream == null) + { + var supportsDirectSocketAccess = !_context.Response.SendChunked && !isExpect100Continue && !secure; + + _responseStream = new HttpResponseStream(_stream, _context.Response, false, _streamHelper, _socket, supportsDirectSocketAccess, _environment, _fileSystem, _logger); + } + return _responseStream; + } + + private static void OnRead(IAsyncResult ares) + { + var cnc = (HttpConnection)ares.AsyncState; + cnc.OnReadInternal(ares); + } + + private void OnReadInternal(IAsyncResult ares) + { + int nread = -1; + try + { + nread = _stream.EndRead(ares); + _memoryStream.Write(_buffer, 0, nread); + if (_memoryStream.Length > 32768) + { + SendError("Bad Request", 400); + Close(true); + return; + } + } + catch + { + if (_memoryStream != null && _memoryStream.Length > 0) + { + SendError(); + } + + if (_socket != null) + { + CloseSocket(); + Unbind(); + } + return; + } + + if (nread == 0) + { + CloseSocket(); + Unbind(); + return; + } + + if (ProcessInput(_memoryStream)) + { + if (!_context.HaveError) + _context.Request.FinishInitialization(); + + if (_context.HaveError) + { + SendError(); + Close(true); + return; + } + + if (!_epl.BindContext(_context)) + { + const int NotFoundErrorCode = 404; + SendError(HttpStatusDescription.Get(NotFoundErrorCode), NotFoundErrorCode); + Close(true); + return; + } + HttpListener listener = _epl.Listener; + if (_lastListener != listener) + { + RemoveConnection(); + listener.AddConnection(this); + _lastListener = listener; + } + + _contextBound = true; + listener.RegisterContext(_context); + return; + } + _stream.BeginRead(_buffer, 0, BufferSize, s_onreadCallback, this); + } + + private void RemoveConnection() + { + if (_lastListener == null) + _epl.RemoveConnection(this); + else + _lastListener.RemoveConnection(this); + } + + private enum InputState + { + RequestLine, + Headers + } + + private enum LineState + { + None, + CR, + LF + } + + InputState _inputState = InputState.RequestLine; + LineState _lineState = LineState.None; + int _position; + + // true -> done processing + // false -> need more input + private bool ProcessInput(MemoryStream ms) + { + byte[] buffer = ms.GetBuffer(); + int len = (int)ms.Length; + int used = 0; + string line; + + while (true) + { + if (_context.HaveError) + return true; + + if (_position >= len) + break; + + try + { + line = ReadLine(buffer, _position, len - _position, ref used); + _position += used; + } + catch + { + _context.ErrorMessage = "Bad request"; + _context.ErrorStatus = 400; + return true; + } + + if (line == null) + break; + + if (line == "") + { + if (_inputState == InputState.RequestLine) + continue; + _currentLine = null; + ms = null; + return true; + } + + if (_inputState == InputState.RequestLine) + { + _context.Request.SetRequestLine(line); + _inputState = InputState.Headers; + } + else + { + try + { + _context.Request.AddHeader(line); + } + catch (Exception e) + { + _context.ErrorMessage = e.Message; + _context.ErrorStatus = 400; + return true; + } + } + } + + if (used == len) + { + ms.SetLength(0); + _position = 0; + } + return false; + } + + private string ReadLine(byte[] buffer, int offset, int len, ref int used) + { + if (_currentLine == null) + _currentLine = new StringBuilder(128); + int last = offset + len; + used = 0; + for (int i = offset; i < last && _lineState != LineState.LF; i++) + { + used++; + byte b = buffer[i]; + if (b == 13) + { + _lineState = LineState.CR; + } + else if (b == 10) + { + _lineState = LineState.LF; + } + else + { + _currentLine.Append((char)b); + } + } + + string result = null; + if (_lineState == LineState.LF) + { + _lineState = LineState.None; + result = _currentLine.ToString(); + _currentLine.Length = 0; + } + + return result; + } + + public void SendError(string msg, int status) + { + try + { + HttpListenerResponse response = _context.Response; + response.StatusCode = status; + response.ContentType = "text/html"; + string description = HttpStatusDescription.Get(status); + string str; + if (msg != null) + str = string.Format("<h1>{0} ({1})</h1>", description, msg); + else + str = string.Format("<h1>{0}</h1>", description); + + byte[] error = Encoding.UTF8.GetBytes(str); + response.Close(error, false); + } + catch + { + // response was already closed + } + } + + public void SendError() + { + SendError(_context.ErrorMessage, _context.ErrorStatus); + } + + private void Unbind() + { + if (_contextBound) + { + _epl.UnbindContext(_context); + _contextBound = false; + } + } + + public void Close() + { + Close(false); + } + + private void CloseSocket() + { + if (_socket == null) + return; + + try + { + _socket.Close(); + } + catch { } + finally + { + _socket = null; + } + + RemoveConnection(); + } + + internal void Close(bool force) + { + if (_socket != null) + { + Stream st = GetResponseStream(); + if (st != null) + st.Close(); + + _responseStream = null; + } + + if (_socket != null) + { + force |= !_context.Request.KeepAlive; + if (!force) + { + force = string.Equals(_context.Response.Headers["connection"], "close", StringComparison.OrdinalIgnoreCase); + } + + if (!force && _context.Request.FlushInput()) + { + if (_chunked && _context.Response.ForceCloseChunked == false) + { + // Don't close. Keep working. + _reuses++; + Unbind(); + InitInternal(); + BeginReadRequest(); + return; + } + + _reuses++; + Unbind(); + InitInternal(); + BeginReadRequest(); + return; + } + + Socket s = _socket; + _socket = null; + try + { + s?.Shutdown(SocketShutdown.Both); + } + catch + { + } + finally + { + try + { + s?.Close(); + } + catch { } + } + Unbind(); + RemoveConnection(); + return; + } + } + } +} |
