diff options
| author | Luke Pulverenti <luke.pulverenti@gmail.com> | 2017-05-25 09:00:14 -0400 |
|---|---|---|
| committer | Luke Pulverenti <luke.pulverenti@gmail.com> | 2017-05-25 09:00:14 -0400 |
| commit | 28988b056ccc8efad54905b6f10ff0b9532c7130 (patch) | |
| tree | e5ef1b92cf28b884bb03bbfd67112a25e48a4fe7 /SocketHttpListener/Net/HttpConnection.cs | |
| parent | d035d7eaec937b1ad43af6a95f723070c1e847ea (diff) | |
update stream copying
Diffstat (limited to 'SocketHttpListener/Net/HttpConnection.cs')
| -rw-r--r-- | SocketHttpListener/Net/HttpConnection.cs | 377 |
1 files changed, 182 insertions, 195 deletions
diff --git a/SocketHttpListener/Net/HttpConnection.cs b/SocketHttpListener/Net/HttpConnection.cs index eda633207..9c87ff076 100644 --- a/SocketHttpListener/Net/HttpConnection.cs +++ b/SocketHttpListener/Net/HttpConnection.cs @@ -14,24 +14,25 @@ namespace SocketHttpListener.Net { sealed class HttpConnection { + private static AsyncCallback s_onreadCallback = new AsyncCallback(OnRead); const int BufferSize = 8192; - IAcceptSocket sock; - Stream stream; - EndPointListener epl; - MemoryStream ms; - byte[] buffer; - HttpListenerContext context; - StringBuilder current_line; - ListenerPrefix prefix; - HttpRequestStream i_stream; - Stream o_stream; - bool chunked; - int reuses; - bool context_bound; + IAcceptSocket _socket; + Stream _stream; + EndPointListener _epl; + MemoryStream _memoryStream; + byte[] _buffer; + HttpListenerContext _context; + StringBuilder _currentLine; + ListenerPrefix _prefix; + HttpRequestStream _requestStream; + Stream _responseStream; + bool _chunked; + int _reuses; + bool _contextBound; bool secure; - int s_timeout = 300000; // 90k ms for first request, 15k ms from then on + int _timeout = 300000; // 90k ms for first request, 15k ms from then on IpEndPointInfo local_ep; - HttpListener last_listener; + HttpListener _lastListener; int[] client_cert_errors; ICertificate cert; Stream ssl_stream; @@ -44,11 +45,11 @@ namespace SocketHttpListener.Net private readonly IFileSystem _fileSystem; private readonly IEnvironmentInfo _environment; - private HttpConnection(ILogger logger, IAcceptSocket sock, EndPointListener epl, bool secure, ICertificate cert, ICryptoProvider cryptoProvider, IStreamFactory streamFactory, IMemoryStreamFactory memoryStreamFactory, ITextEncoding textEncoding, IFileSystem fileSystem, IEnvironmentInfo environment) + private HttpConnection(ILogger logger, IAcceptSocket socket, EndPointListener epl, bool secure, ICertificate cert, ICryptoProvider cryptoProvider, IStreamFactory streamFactory, IMemoryStreamFactory memoryStreamFactory, ITextEncoding textEncoding, IFileSystem fileSystem, IEnvironmentInfo environment) { _logger = logger; - this.sock = sock; - this.epl = epl; + this._socket = socket; + this._epl = epl; this.secure = secure; this.cert = cert; _cryptoProvider = cryptoProvider; @@ -63,11 +64,11 @@ namespace SocketHttpListener.Net { if (secure == false) { - stream = _streamFactory.CreateNetworkStream(sock, false); + _stream = _streamFactory.CreateNetworkStream(_socket, false); } else { - //ssl_stream = epl.Listener.CreateSslStream(new NetworkStream(sock, false), false, (t, c, ch, e) => + //ssl_stream = _epl.Listener.CreateSslStream(new NetworkStream(_socket, false), false, (t, c, ch, e) => //{ // if (c == null) // return true; @@ -78,11 +79,11 @@ namespace SocketHttpListener.Net // client_cert_errors = new int[] { (int)e }; // return true; //}); - //stream = ssl_stream.AuthenticatedStream; + //_stream = ssl_stream.AuthenticatedStream; - ssl_stream = _streamFactory.CreateSslStream(_streamFactory.CreateNetworkStream(sock, false), false); + ssl_stream = _streamFactory.CreateSslStream(_streamFactory.CreateNetworkStream(_socket, false), false); await _streamFactory.AuthenticateSslStreamAsServer(ssl_stream, cert).ConfigureAwait(false); - stream = ssl_stream; + _stream = ssl_stream; } Init(); } @@ -100,7 +101,7 @@ namespace SocketHttpListener.Net { get { - return stream; + return _stream; } } @@ -111,32 +112,26 @@ namespace SocketHttpListener.Net void Init() { - if (ssl_stream != null) - { - //ssl_stream.AuthenticateAsServer(client_cert, true, (SslProtocols)ServicePointManager.SecurityProtocol, false); - //_streamFactory.AuthenticateSslStreamAsServer(ssl_stream, cert); - } - - context_bound = false; - i_stream = null; - o_stream = null; - prefix = null; - chunked = false; - ms = _memoryStreamFactory.CreateNew(); - position = 0; - input_state = InputState.RequestLine; - line_state = LineState.None; - context = new HttpListenerContext(this, _logger, _cryptoProvider, _memoryStreamFactory, _textEncoding, _fileSystem); + _contextBound = false; + _requestStream = null; + _responseStream = null; + _prefix = null; + _chunked = false; + _memoryStream = new MemoryStream(); + _position = 0; + _inputState = InputState.RequestLine; + _lineState = LineState.None; + _context = new HttpListenerContext(this, _logger, _cryptoProvider, _memoryStreamFactory, _textEncoding, _fileSystem); } public bool IsClosed { - get { return (sock == null); } + get { return (_socket == null); } } public int Reuses { - get { return reuses; } + get { return _reuses; } } public IpEndPointInfo LocalEndPoint @@ -146,14 +141,14 @@ namespace SocketHttpListener.Net if (local_ep != null) return local_ep; - local_ep = (IpEndPointInfo)sock.LocalEndPoint; + local_ep = (IpEndPointInfo)_socket.LocalEndPoint; return local_ep; } } public IpEndPointInfo RemoteEndPoint { - get { return (IpEndPointInfo)sock.RemoteEndPoint; } + get { return (IpEndPointInfo)_socket.RemoteEndPoint; } } public bool IsSecure @@ -163,187 +158,186 @@ namespace SocketHttpListener.Net public ListenerPrefix Prefix { - get { return prefix; } - set { prefix = value; } + get { return _prefix; } + set { _prefix = value; } } - public async Task BeginReadRequest() + public void BeginReadRequest() { - if (buffer == null) - buffer = new byte[BufferSize]; - + if (_buffer == null) + _buffer = new byte[BufferSize]; try { - //if (reuses == 1) - // s_timeout = 15000; - var nRead = await stream.ReadAsync(buffer, 0, BufferSize).ConfigureAwait(false); - - OnReadInternal(nRead); + if (_reuses == 1) + _timeout = 15000; + //_timer.Change(_timeout, Timeout.Infinite); + _stream.BeginRead(_buffer, 0, BufferSize, s_onreadCallback, this); } - catch (Exception ex) + catch { - OnReadInternalException(ms, ex); + //_timer.Change(Timeout.Infinite, Timeout.Infinite); + CloseSocket(); + Unbind(); } } public HttpRequestStream GetRequestStream(bool chunked, long contentlength) { - if (i_stream == null) + if (_requestStream == null) { - byte[] buffer; - _memoryStreamFactory.TryGetBuffer(ms, out buffer); - - int length = (int)ms.Length; - ms = null; + byte[] buffer = _memoryStream.GetBuffer(); + int length = (int)_memoryStream.Length; + _memoryStream = null; if (chunked) { - this.chunked = true; - //context.Response.SendChunked = true; - i_stream = new ChunkedInputStream(context, stream, buffer, position, length - position); + _chunked = true; + //_context.Response.SendChunked = true; + _requestStream = new ChunkedInputStream(_context, _stream, buffer, _position, length - _position); } else { - i_stream = new HttpRequestStream(stream, buffer, position, length - position, contentlength); + _requestStream = new HttpRequestStream(_stream, buffer, _position, length - _position, contentlength); } } - return i_stream; + return _requestStream; } public Stream GetResponseStream(bool isExpect100Continue = false) { - // TODO: can we get this stream before reading the input? - if (o_stream == null) + // TODO: can we get this _stream before reading the input? + if (_responseStream == null) { - //context.Response.DetermineIfChunked(); - - var supportsDirectSocketAccess = !context.Response.SendChunked && !isExpect100Continue && !secure; + var supportsDirectSocketAccess = !_context.Response.SendChunked && !isExpect100Continue && !secure; - //o_stream = new ResponseStream(stream, context.Response, _memoryStreamFactory, _textEncoding, _fileSystem, sock, supportsDirectSocketAccess, _logger, _environment); - - o_stream = new HttpResponseStream(stream, context.Response, false, _memoryStreamFactory, sock, supportsDirectSocketAccess, _environment, _fileSystem, _logger); + _responseStream = new HttpResponseStream(_stream, _context.Response, false, _memoryStreamFactory, _socket, supportsDirectSocketAccess, _environment, _fileSystem, _logger); } - return o_stream; + return _responseStream; } - void OnReadInternal(int nread) + private static void OnRead(IAsyncResult ares) { - ms.Write(buffer, 0, nread); - if (ms.Length > 32768) + HttpConnection cnc = (HttpConnection)ares.AsyncState; + cnc.OnReadInternal(ares); + } + + private void OnReadInternal(IAsyncResult ares) + { + //_timer.Change(Timeout.Infinite, Timeout.Infinite); + int nread = -1; + try + { + nread = _stream.EndRead(ares); + _memoryStream.Write(_buffer, 0, nread); + if (_memoryStream.Length > 32768) + { + SendError("Bad Request", 400); + Close(true); + return; + } + } + catch { - SendError("Bad request", 400); - Close(true); + if (_memoryStream != null && _memoryStream.Length > 0) + SendError(); + if (_socket != null) + { + CloseSocket(); + Unbind(); + } return; } if (nread == 0) { - //if (ms.Length > 0) - // SendError (); // Why bother? CloseSocket(); Unbind(); return; } - if (ProcessInput(ms)) + if (ProcessInput(_memoryStream)) { - if (!context.HaveError) - context.Request.FinishInitialization(); + if (!_context.HaveError) + _context.Request.FinishInitialization(); - if (context.HaveError) + if (_context.HaveError) { SendError(); Close(true); return; } - if (!epl.BindContext(context)) + if (!_epl.BindContext(_context)) { SendError("Invalid host", 400); Close(true); return; } - HttpListener listener = epl.Listener; - if (last_listener != listener) + HttpListener listener = _epl.Listener; + if (_lastListener != listener) { RemoveConnection(); listener.AddConnection(this); - last_listener = listener; + _lastListener = listener; } - context_bound = true; - listener.RegisterContext(context); + _contextBound = true; + listener.RegisterContext(_context); return; } - - BeginReadRequest(); - } - - private void OnReadInternalException(MemoryStream ms, Exception ex) - { - //_logger.ErrorException("Error in HttpConnection.OnReadInternal", ex); - - if (ms != null && ms.Length > 0) - SendError(); - if (sock != null) - { - CloseSocket(); - Unbind(); - } + _stream.BeginRead(_buffer, 0, BufferSize, s_onreadCallback, this); } - void RemoveConnection() + private void RemoveConnection() { - if (last_listener == null) - epl.RemoveConnection(this); + if (_lastListener == null) + _epl.RemoveConnection(this); else - last_listener.RemoveConnection(this); + _lastListener.RemoveConnection(this); } - enum InputState + private enum InputState { RequestLine, Headers } - enum LineState + private enum LineState { None, CR, LF } - InputState input_state = InputState.RequestLine; - LineState line_state = LineState.None; - int position; + InputState _inputState = InputState.RequestLine; + LineState _lineState = LineState.None; + int _position; // true -> done processing // false -> need more input - bool ProcessInput(MemoryStream ms) + private bool ProcessInput(MemoryStream ms) { - byte[] buffer; - _memoryStreamFactory.TryGetBuffer(ms, out buffer); - + byte[] buffer = ms.GetBuffer(); int len = (int)ms.Length; int used = 0; string line; while (true) { - if (context.HaveError) + if (_context.HaveError) return true; - if (position >= len) + if (_position >= len) break; try { - line = ReadLine(buffer, position, len - position, ref used); - position += used; + line = ReadLine(buffer, _position, len - _position, ref used); + _position += used; } catch { - context.ErrorMessage = "Bad request"; - context.ErrorStatus = 400; + _context.ErrorMessage = "Bad request"; + _context.ErrorStatus = 400; return true; } @@ -352,28 +346,28 @@ namespace SocketHttpListener.Net if (line == "") { - if (input_state == InputState.RequestLine) + if (_inputState == InputState.RequestLine) continue; - current_line = null; + _currentLine = null; ms = null; return true; } - if (input_state == InputState.RequestLine) + if (_inputState == InputState.RequestLine) { - context.Request.SetRequestLine(line); - input_state = InputState.Headers; + _context.Request.SetRequestLine(line); + _inputState = InputState.Headers; } else { try { - context.Request.AddHeader(line); + _context.Request.AddHeader(line); } catch (Exception e) { - context.ErrorMessage = e.Message; - context.ErrorStatus = 400; + _context.ErrorMessage = e.Message; + _context.ErrorStatus = 400; return true; } } @@ -382,42 +376,41 @@ namespace SocketHttpListener.Net if (used == len) { ms.SetLength(0); - position = 0; + _position = 0; } return false; } - string ReadLine(byte[] buffer, int offset, int len, ref int used) + private string ReadLine(byte[] buffer, int offset, int len, ref int used) { - if (current_line == null) - current_line = new StringBuilder(128); + if (_currentLine == null) + _currentLine = new StringBuilder(128); int last = offset + len; used = 0; - - for (int i = offset; i < last && line_state != LineState.LF; i++) + for (int i = offset; i < last && _lineState != LineState.LF; i++) { used++; byte b = buffer[i]; if (b == 13) { - line_state = LineState.CR; + _lineState = LineState.CR; } else if (b == 10) { - line_state = LineState.LF; + _lineState = LineState.LF; } else { - current_line.Append((char)b); + _currentLine.Append((char)b); } } string result = null; - if (line_state == LineState.LF) + if (_lineState == LineState.LF) { - line_state = LineState.None; - result = current_line.ToString(); - current_line.Length = 0; + _lineState = LineState.None; + result = _currentLine.ToString(); + _currentLine.Length = 0; } return result; @@ -427,20 +420,18 @@ namespace SocketHttpListener.Net { try { - HttpListenerResponse response = context.Response; + HttpListenerResponse response = _context.Response; response.StatusCode = status; response.ContentType = "text/html"; string description = HttpListenerResponse.GetStatusDescription(status); string str; if (msg != null) - str = String.Format("<h1>{0} ({1})</h1>", description, msg); + str = string.Format("<h1>{0} ({1})</h1>", description, msg); else - str = String.Format("<h1>{0}</h1>", description); + str = string.Format("<h1>{0}</h1>", description); - byte[] error = context.Response.ContentEncoding.GetBytes(str); - response.ContentLength64 = error.Length; - response.OutputStream.Write(error, 0, (int)error.Length); - response.Close(); + byte[] error = Encoding.Default.GetBytes(str); + response.Close(error, false); } catch { @@ -450,15 +441,15 @@ namespace SocketHttpListener.Net public void SendError() { - SendError(context.ErrorMessage, context.ErrorStatus); + SendError(_context.ErrorMessage, _context.ErrorStatus); } - void Unbind() + private void Unbind() { - if (context_bound) + if (_contextBound) { - epl.UnbindContext(context); - context_bound = false; + _epl.UnbindContext(_context); + _contextBound = false; } } @@ -469,64 +460,60 @@ namespace SocketHttpListener.Net private void CloseSocket() { - if (sock == null) + if (_socket == null) return; try { - sock.Close(); - } - catch - { + _socket.Close(); } + catch { } finally { - sock = null; + _socket = null; } + RemoveConnection(); } - internal void Close(bool force_close) + internal void Close(bool force) { - if (sock != null) + if (_socket != null) { - if (!context.Request.IsWebSocketRequest || force_close) - { - Stream st = GetResponseStream(); - if (st != null) - { - st.Dispose(); - } + Stream st = GetResponseStream(); + if (st != null) + st.Close(); - o_stream = null; - } + _responseStream = null; } - if (sock != null) + if (_socket != null) { - force_close |= !context.Request.KeepAlive; - if (!force_close) - force_close = (string.Equals(context.Response.Headers["connection"], "close", StringComparison.OrdinalIgnoreCase)); - /* - if (!force_close) { -// bool conn_close = (status_code == 400 || status_code == 408 || status_code == 411 || -// status_code == 413 || status_code == 414 || status_code == 500 || -// status_code == 503); - force_close |= (context.Request.ProtocolVersion <= HttpVersion.Version10); - } - */ - - if (!force_close && context.Request.FlushInput()) + force |= !_context.Request.KeepAlive; + if (!force) + force = (string.Equals(_context.Response.Headers["connection"], "close", StringComparison.OrdinalIgnoreCase)); + + if (!force && _context.Request.FlushInput()) { - reuses++; + if (_chunked && _context.Response.ForceCloseChunked == false) + { + // Don't close. Keep working. + _reuses++; + Unbind(); + Init(); + BeginReadRequest(); + return; + } + + _reuses++; Unbind(); Init(); BeginReadRequest(); return; } - IAcceptSocket s = sock; - sock = null; + IAcceptSocket s = _socket; + _socket = null; try { if (s != null) |
