aboutsummaryrefslogtreecommitdiff
path: root/SocketHttpListener/Net/HttpConnection.cs
diff options
context:
space:
mode:
Diffstat (limited to 'SocketHttpListener/Net/HttpConnection.cs')
-rw-r--r--SocketHttpListener/Net/HttpConnection.cs532
1 files changed, 532 insertions, 0 deletions
diff --git a/SocketHttpListener/Net/HttpConnection.cs b/SocketHttpListener/Net/HttpConnection.cs
new file mode 100644
index 000000000..5beea5f22
--- /dev/null
+++ b/SocketHttpListener/Net/HttpConnection.cs
@@ -0,0 +1,532 @@
+using System;
+using System.IO;
+using System.Net;
+using System.Net.Security;
+using System.Net.Sockets;
+using System.Security.Authentication;
+using System.Security.Cryptography.X509Certificates;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using MediaBrowser.Model.Cryptography;
+using MediaBrowser.Model.IO;
+using MediaBrowser.Model.System;
+using Microsoft.Extensions.Logging;
+namespace SocketHttpListener.Net
+{
+ sealed class HttpConnection
+ {
+ private static AsyncCallback s_onreadCallback = new AsyncCallback(OnRead);
+ const int BufferSize = 8192;
+ Socket _socket;
+ Stream _stream;
+ HttpEndPointListener _epl;
+ MemoryStream _memoryStream;
+ byte[] _buffer;
+ HttpListenerContext _context;
+ StringBuilder _currentLine;
+ ListenerPrefix _prefix;
+ HttpRequestStream _requestStream;
+ HttpResponseStream _responseStream;
+ bool _chunked;
+ int _reuses;
+ bool _contextBound;
+ bool secure;
+ IPEndPoint local_ep;
+ HttpListener _lastListener;
+ X509Certificate cert;
+ SslStream ssl_stream;
+
+ private readonly ILogger _logger;
+ private readonly ICryptoProvider _cryptoProvider;
+ private readonly IStreamHelper _streamHelper;
+ private readonly IFileSystem _fileSystem;
+ private readonly IEnvironmentInfo _environment;
+
+ public HttpConnection(ILogger logger, Socket socket, HttpEndPointListener epl, bool secure,
+ X509Certificate cert, ICryptoProvider cryptoProvider, IStreamHelper streamHelper, IFileSystem fileSystem,
+ IEnvironmentInfo environment)
+ {
+ _logger = logger;
+ this._socket = socket;
+ this._epl = epl;
+ this.secure = secure;
+ this.cert = cert;
+ _cryptoProvider = cryptoProvider;
+ _streamHelper = streamHelper;
+ _fileSystem = fileSystem;
+ _environment = environment;
+
+ if (secure == false)
+ {
+ _stream = new SocketStream(_socket, false);
+ }
+ else
+ {
+ ssl_stream = new SslStream(new SocketStream(_socket, false), false, (t, c, ch, e) =>
+ {
+ if (c == null)
+ {
+ return true;
+ }
+
+ //var c2 = c as X509Certificate2;
+ //if (c2 == null)
+ //{
+ // c2 = new X509Certificate2(c.GetRawCertData());
+ //}
+
+ //_clientCert = c2;
+ //_clientCertErrors = new int[] { (int)e };
+ return true;
+ });
+
+ _stream = ssl_stream;
+ }
+ }
+
+ public Stream Stream => _stream;
+
+ public async Task Init()
+ {
+ if (ssl_stream != null)
+ {
+ var enableAsync = true;
+ if (enableAsync)
+ {
+ await ssl_stream.AuthenticateAsServerAsync(cert, false, (SslProtocols)ServicePointManager.SecurityProtocol, false).ConfigureAwait(false);
+ }
+ else
+ {
+ ssl_stream.AuthenticateAsServer(cert, false, (SslProtocols)ServicePointManager.SecurityProtocol, false);
+ }
+ }
+
+ InitInternal();
+ }
+
+ private void InitInternal()
+ {
+ _contextBound = false;
+ _requestStream = null;
+ _responseStream = null;
+ _prefix = null;
+ _chunked = false;
+ _memoryStream = new MemoryStream();
+ _position = 0;
+ _inputState = InputState.RequestLine;
+ _lineState = LineState.None;
+ _context = new HttpListenerContext(this);
+ }
+
+ public bool IsClosed => (_socket == null);
+
+ public int Reuses => _reuses;
+
+ public IPEndPoint LocalEndPoint
+ {
+ get
+ {
+ if (local_ep != null)
+ return local_ep;
+
+ local_ep = (IPEndPoint)_socket.LocalEndPoint;
+ return local_ep;
+ }
+ }
+
+ public IPEndPoint RemoteEndPoint => _socket.RemoteEndPoint as IPEndPoint;
+
+ public bool IsSecure => secure;
+
+ public ListenerPrefix Prefix
+ {
+ get => _prefix;
+ set => _prefix = value;
+ }
+
+ private void OnTimeout(object unused)
+ {
+ //_logger.LogInformation("HttpConnection timer fired");
+ CloseSocket();
+ Unbind();
+ }
+
+ public void BeginReadRequest()
+ {
+ if (_buffer == null)
+ {
+ _buffer = new byte[BufferSize];
+ }
+
+ try
+ {
+ _stream.BeginRead(_buffer, 0, BufferSize, s_onreadCallback, this);
+ }
+ catch
+ {
+ CloseSocket();
+ Unbind();
+ }
+ }
+
+ public HttpRequestStream GetRequestStream(bool chunked, long contentlength)
+ {
+ if (_requestStream == null)
+ {
+ byte[] buffer = _memoryStream.GetBuffer();
+ int length = (int)_memoryStream.Length;
+ _memoryStream = null;
+ if (chunked)
+ {
+ _chunked = true;
+ //_context.Response.SendChunked = true;
+ _requestStream = new ChunkedInputStream(_context, _stream, buffer, _position, length - _position);
+ }
+ else
+ {
+ _requestStream = new HttpRequestStream(_stream, buffer, _position, length - _position, contentlength);
+ }
+ }
+ return _requestStream;
+ }
+
+ public HttpResponseStream GetResponseStream(bool isExpect100Continue = false)
+ {
+ // TODO: can we get this _stream before reading the input?
+ if (_responseStream == null)
+ {
+ var supportsDirectSocketAccess = !_context.Response.SendChunked && !isExpect100Continue && !secure;
+
+ _responseStream = new HttpResponseStream(_stream, _context.Response, false, _streamHelper, _socket, supportsDirectSocketAccess, _environment, _fileSystem, _logger);
+ }
+ return _responseStream;
+ }
+
+ private static void OnRead(IAsyncResult ares)
+ {
+ var cnc = (HttpConnection)ares.AsyncState;
+ cnc.OnReadInternal(ares);
+ }
+
+ private void OnReadInternal(IAsyncResult ares)
+ {
+ int nread = -1;
+ try
+ {
+ nread = _stream.EndRead(ares);
+ _memoryStream.Write(_buffer, 0, nread);
+ if (_memoryStream.Length > 32768)
+ {
+ SendError("Bad Request", 400);
+ Close(true);
+ return;
+ }
+ }
+ catch
+ {
+ if (_memoryStream != null && _memoryStream.Length > 0)
+ {
+ SendError();
+ }
+
+ if (_socket != null)
+ {
+ CloseSocket();
+ Unbind();
+ }
+ return;
+ }
+
+ if (nread == 0)
+ {
+ CloseSocket();
+ Unbind();
+ return;
+ }
+
+ if (ProcessInput(_memoryStream))
+ {
+ if (!_context.HaveError)
+ _context.Request.FinishInitialization();
+
+ if (_context.HaveError)
+ {
+ SendError();
+ Close(true);
+ return;
+ }
+
+ if (!_epl.BindContext(_context))
+ {
+ const int NotFoundErrorCode = 404;
+ SendError(HttpStatusDescription.Get(NotFoundErrorCode), NotFoundErrorCode);
+ Close(true);
+ return;
+ }
+ HttpListener listener = _epl.Listener;
+ if (_lastListener != listener)
+ {
+ RemoveConnection();
+ listener.AddConnection(this);
+ _lastListener = listener;
+ }
+
+ _contextBound = true;
+ listener.RegisterContext(_context);
+ return;
+ }
+ _stream.BeginRead(_buffer, 0, BufferSize, s_onreadCallback, this);
+ }
+
+ private void RemoveConnection()
+ {
+ if (_lastListener == null)
+ _epl.RemoveConnection(this);
+ else
+ _lastListener.RemoveConnection(this);
+ }
+
+ private enum InputState
+ {
+ RequestLine,
+ Headers
+ }
+
+ private enum LineState
+ {
+ None,
+ CR,
+ LF
+ }
+
+ InputState _inputState = InputState.RequestLine;
+ LineState _lineState = LineState.None;
+ int _position;
+
+ // true -> done processing
+ // false -> need more input
+ private bool ProcessInput(MemoryStream ms)
+ {
+ byte[] buffer = ms.GetBuffer();
+ int len = (int)ms.Length;
+ int used = 0;
+ string line;
+
+ while (true)
+ {
+ if (_context.HaveError)
+ return true;
+
+ if (_position >= len)
+ break;
+
+ try
+ {
+ line = ReadLine(buffer, _position, len - _position, ref used);
+ _position += used;
+ }
+ catch
+ {
+ _context.ErrorMessage = "Bad request";
+ _context.ErrorStatus = 400;
+ return true;
+ }
+
+ if (line == null)
+ break;
+
+ if (line == "")
+ {
+ if (_inputState == InputState.RequestLine)
+ continue;
+ _currentLine = null;
+ ms = null;
+ return true;
+ }
+
+ if (_inputState == InputState.RequestLine)
+ {
+ _context.Request.SetRequestLine(line);
+ _inputState = InputState.Headers;
+ }
+ else
+ {
+ try
+ {
+ _context.Request.AddHeader(line);
+ }
+ catch (Exception e)
+ {
+ _context.ErrorMessage = e.Message;
+ _context.ErrorStatus = 400;
+ return true;
+ }
+ }
+ }
+
+ if (used == len)
+ {
+ ms.SetLength(0);
+ _position = 0;
+ }
+ return false;
+ }
+
+ private string ReadLine(byte[] buffer, int offset, int len, ref int used)
+ {
+ if (_currentLine == null)
+ _currentLine = new StringBuilder(128);
+ int last = offset + len;
+ used = 0;
+ for (int i = offset; i < last && _lineState != LineState.LF; i++)
+ {
+ used++;
+ byte b = buffer[i];
+ if (b == 13)
+ {
+ _lineState = LineState.CR;
+ }
+ else if (b == 10)
+ {
+ _lineState = LineState.LF;
+ }
+ else
+ {
+ _currentLine.Append((char)b);
+ }
+ }
+
+ string result = null;
+ if (_lineState == LineState.LF)
+ {
+ _lineState = LineState.None;
+ result = _currentLine.ToString();
+ _currentLine.Length = 0;
+ }
+
+ return result;
+ }
+
+ public void SendError(string msg, int status)
+ {
+ try
+ {
+ HttpListenerResponse response = _context.Response;
+ response.StatusCode = status;
+ response.ContentType = "text/html";
+ string description = HttpStatusDescription.Get(status);
+ string str;
+ if (msg != null)
+ str = string.Format("<h1>{0} ({1})</h1>", description, msg);
+ else
+ str = string.Format("<h1>{0}</h1>", description);
+
+ byte[] error = Encoding.UTF8.GetBytes(str);
+ response.Close(error, false);
+ }
+ catch
+ {
+ // response was already closed
+ }
+ }
+
+ public void SendError()
+ {
+ SendError(_context.ErrorMessage, _context.ErrorStatus);
+ }
+
+ private void Unbind()
+ {
+ if (_contextBound)
+ {
+ _epl.UnbindContext(_context);
+ _contextBound = false;
+ }
+ }
+
+ public void Close()
+ {
+ Close(false);
+ }
+
+ private void CloseSocket()
+ {
+ if (_socket == null)
+ return;
+
+ try
+ {
+ _socket.Close();
+ }
+ catch { }
+ finally
+ {
+ _socket = null;
+ }
+
+ RemoveConnection();
+ }
+
+ internal void Close(bool force)
+ {
+ if (_socket != null)
+ {
+ Stream st = GetResponseStream();
+ if (st != null)
+ st.Close();
+
+ _responseStream = null;
+ }
+
+ if (_socket != null)
+ {
+ force |= !_context.Request.KeepAlive;
+ if (!force)
+ {
+ force = string.Equals(_context.Response.Headers["connection"], "close", StringComparison.OrdinalIgnoreCase);
+ }
+
+ if (!force && _context.Request.FlushInput())
+ {
+ if (_chunked && _context.Response.ForceCloseChunked == false)
+ {
+ // Don't close. Keep working.
+ _reuses++;
+ Unbind();
+ InitInternal();
+ BeginReadRequest();
+ return;
+ }
+
+ _reuses++;
+ Unbind();
+ InitInternal();
+ BeginReadRequest();
+ return;
+ }
+
+ Socket s = _socket;
+ _socket = null;
+ try
+ {
+ s?.Shutdown(SocketShutdown.Both);
+ }
+ catch
+ {
+ }
+ finally
+ {
+ try
+ {
+ s?.Close();
+ }
+ catch { }
+ }
+ Unbind();
+ RemoveConnection();
+ return;
+ }
+ }
+ }
+}