update stream copying

This commit is contained in:
Luke Pulverenti
2017-05-25 09:00:14 -04:00
parent d035d7eaec
commit 28988b056c
12 changed files with 468 additions and 406 deletions

View File

@@ -14,24 +14,25 @@ namespace SocketHttpListener.Net
{
sealed class HttpConnection
{
private static AsyncCallback s_onreadCallback = new AsyncCallback(OnRead);
const int BufferSize = 8192;
IAcceptSocket sock;
Stream stream;
EndPointListener epl;
MemoryStream ms;
byte[] buffer;
HttpListenerContext context;
StringBuilder current_line;
ListenerPrefix prefix;
HttpRequestStream i_stream;
Stream o_stream;
bool chunked;
int reuses;
bool context_bound;
IAcceptSocket _socket;
Stream _stream;
EndPointListener _epl;
MemoryStream _memoryStream;
byte[] _buffer;
HttpListenerContext _context;
StringBuilder _currentLine;
ListenerPrefix _prefix;
HttpRequestStream _requestStream;
Stream _responseStream;
bool _chunked;
int _reuses;
bool _contextBound;
bool secure;
int s_timeout = 300000; // 90k ms for first request, 15k ms from then on
int _timeout = 300000; // 90k ms for first request, 15k ms from then on
IpEndPointInfo local_ep;
HttpListener last_listener;
HttpListener _lastListener;
int[] client_cert_errors;
ICertificate cert;
Stream ssl_stream;
@@ -44,11 +45,11 @@ namespace SocketHttpListener.Net
private readonly IFileSystem _fileSystem;
private readonly IEnvironmentInfo _environment;
private HttpConnection(ILogger logger, IAcceptSocket sock, EndPointListener epl, bool secure, ICertificate cert, ICryptoProvider cryptoProvider, IStreamFactory streamFactory, IMemoryStreamFactory memoryStreamFactory, ITextEncoding textEncoding, IFileSystem fileSystem, IEnvironmentInfo environment)
private HttpConnection(ILogger logger, IAcceptSocket socket, EndPointListener epl, bool secure, ICertificate cert, ICryptoProvider cryptoProvider, IStreamFactory streamFactory, IMemoryStreamFactory memoryStreamFactory, ITextEncoding textEncoding, IFileSystem fileSystem, IEnvironmentInfo environment)
{
_logger = logger;
this.sock = sock;
this.epl = epl;
this._socket = socket;
this._epl = epl;
this.secure = secure;
this.cert = cert;
_cryptoProvider = cryptoProvider;
@@ -63,11 +64,11 @@ namespace SocketHttpListener.Net
{
if (secure == false)
{
stream = _streamFactory.CreateNetworkStream(sock, false);
_stream = _streamFactory.CreateNetworkStream(_socket, false);
}
else
{
//ssl_stream = epl.Listener.CreateSslStream(new NetworkStream(sock, false), false, (t, c, ch, e) =>
//ssl_stream = _epl.Listener.CreateSslStream(new NetworkStream(_socket, false), false, (t, c, ch, e) =>
//{
// if (c == null)
// return true;
@@ -78,11 +79,11 @@ namespace SocketHttpListener.Net
// client_cert_errors = new int[] { (int)e };
// return true;
//});
//stream = ssl_stream.AuthenticatedStream;
//_stream = ssl_stream.AuthenticatedStream;
ssl_stream = _streamFactory.CreateSslStream(_streamFactory.CreateNetworkStream(sock, false), false);
ssl_stream = _streamFactory.CreateSslStream(_streamFactory.CreateNetworkStream(_socket, false), false);
await _streamFactory.AuthenticateSslStreamAsServer(ssl_stream, cert).ConfigureAwait(false);
stream = ssl_stream;
_stream = ssl_stream;
}
Init();
}
@@ -100,7 +101,7 @@ namespace SocketHttpListener.Net
{
get
{
return stream;
return _stream;
}
}
@@ -111,32 +112,26 @@ namespace SocketHttpListener.Net
void Init()
{
if (ssl_stream != null)
{
//ssl_stream.AuthenticateAsServer(client_cert, true, (SslProtocols)ServicePointManager.SecurityProtocol, false);
//_streamFactory.AuthenticateSslStreamAsServer(ssl_stream, cert);
}
context_bound = false;
i_stream = null;
o_stream = null;
prefix = null;
chunked = false;
ms = _memoryStreamFactory.CreateNew();
position = 0;
input_state = InputState.RequestLine;
line_state = LineState.None;
context = new HttpListenerContext(this, _logger, _cryptoProvider, _memoryStreamFactory, _textEncoding, _fileSystem);
_contextBound = false;
_requestStream = null;
_responseStream = null;
_prefix = null;
_chunked = false;
_memoryStream = new MemoryStream();
_position = 0;
_inputState = InputState.RequestLine;
_lineState = LineState.None;
_context = new HttpListenerContext(this, _logger, _cryptoProvider, _memoryStreamFactory, _textEncoding, _fileSystem);
}
public bool IsClosed
{
get { return (sock == null); }
get { return (_socket == null); }
}
public int Reuses
{
get { return reuses; }
get { return _reuses; }
}
public IpEndPointInfo LocalEndPoint
@@ -146,14 +141,14 @@ namespace SocketHttpListener.Net
if (local_ep != null)
return local_ep;
local_ep = (IpEndPointInfo)sock.LocalEndPoint;
local_ep = (IpEndPointInfo)_socket.LocalEndPoint;
return local_ep;
}
}
public IpEndPointInfo RemoteEndPoint
{
get { return (IpEndPointInfo)sock.RemoteEndPoint; }
get { return (IpEndPointInfo)_socket.RemoteEndPoint; }
}
public bool IsSecure
@@ -163,187 +158,186 @@ namespace SocketHttpListener.Net
public ListenerPrefix Prefix
{
get { return prefix; }
set { prefix = value; }
get { return _prefix; }
set { _prefix = value; }
}
public async Task BeginReadRequest()
public void BeginReadRequest()
{
if (buffer == null)
buffer = new byte[BufferSize];
if (_buffer == null)
_buffer = new byte[BufferSize];
try
{
//if (reuses == 1)
// s_timeout = 15000;
var nRead = await stream.ReadAsync(buffer, 0, BufferSize).ConfigureAwait(false);
OnReadInternal(nRead);
if (_reuses == 1)
_timeout = 15000;
//_timer.Change(_timeout, Timeout.Infinite);
_stream.BeginRead(_buffer, 0, BufferSize, s_onreadCallback, this);
}
catch (Exception ex)
catch
{
OnReadInternalException(ms, ex);
//_timer.Change(Timeout.Infinite, Timeout.Infinite);
CloseSocket();
Unbind();
}
}
public HttpRequestStream GetRequestStream(bool chunked, long contentlength)
{
if (i_stream == null)
if (_requestStream == null)
{
byte[] buffer;
_memoryStreamFactory.TryGetBuffer(ms, out buffer);
int length = (int)ms.Length;
ms = null;
byte[] buffer = _memoryStream.GetBuffer();
int length = (int)_memoryStream.Length;
_memoryStream = null;
if (chunked)
{
this.chunked = true;
//context.Response.SendChunked = true;
i_stream = new ChunkedInputStream(context, stream, buffer, position, length - position);
_chunked = true;
//_context.Response.SendChunked = true;
_requestStream = new ChunkedInputStream(_context, _stream, buffer, _position, length - _position);
}
else
{
i_stream = new HttpRequestStream(stream, buffer, position, length - position, contentlength);
_requestStream = new HttpRequestStream(_stream, buffer, _position, length - _position, contentlength);
}
}
return i_stream;
return _requestStream;
}
public Stream GetResponseStream(bool isExpect100Continue = false)
{
// TODO: can we get this stream before reading the input?
if (o_stream == null)
// TODO: can we get this _stream before reading the input?
if (_responseStream == null)
{
//context.Response.DetermineIfChunked();
var supportsDirectSocketAccess = !_context.Response.SendChunked && !isExpect100Continue && !secure;
var supportsDirectSocketAccess = !context.Response.SendChunked && !isExpect100Continue && !secure;
//o_stream = new ResponseStream(stream, context.Response, _memoryStreamFactory, _textEncoding, _fileSystem, sock, supportsDirectSocketAccess, _logger, _environment);
o_stream = new HttpResponseStream(stream, context.Response, false, _memoryStreamFactory, sock, supportsDirectSocketAccess, _environment, _fileSystem, _logger);
_responseStream = new HttpResponseStream(_stream, _context.Response, false, _memoryStreamFactory, _socket, supportsDirectSocketAccess, _environment, _fileSystem, _logger);
}
return o_stream;
return _responseStream;
}
void OnReadInternal(int nread)
private static void OnRead(IAsyncResult ares)
{
ms.Write(buffer, 0, nread);
if (ms.Length > 32768)
HttpConnection cnc = (HttpConnection)ares.AsyncState;
cnc.OnReadInternal(ares);
}
private void OnReadInternal(IAsyncResult ares)
{
//_timer.Change(Timeout.Infinite, Timeout.Infinite);
int nread = -1;
try
{
SendError("Bad request", 400);
Close(true);
nread = _stream.EndRead(ares);
_memoryStream.Write(_buffer, 0, nread);
if (_memoryStream.Length > 32768)
{
SendError("Bad Request", 400);
Close(true);
return;
}
}
catch
{
if (_memoryStream != null && _memoryStream.Length > 0)
SendError();
if (_socket != null)
{
CloseSocket();
Unbind();
}
return;
}
if (nread == 0)
{
//if (ms.Length > 0)
// SendError (); // Why bother?
CloseSocket();
Unbind();
return;
}
if (ProcessInput(ms))
if (ProcessInput(_memoryStream))
{
if (!context.HaveError)
context.Request.FinishInitialization();
if (!_context.HaveError)
_context.Request.FinishInitialization();
if (context.HaveError)
if (_context.HaveError)
{
SendError();
Close(true);
return;
}
if (!epl.BindContext(context))
if (!_epl.BindContext(_context))
{
SendError("Invalid host", 400);
Close(true);
return;
}
HttpListener listener = epl.Listener;
if (last_listener != listener)
HttpListener listener = _epl.Listener;
if (_lastListener != listener)
{
RemoveConnection();
listener.AddConnection(this);
last_listener = listener;
_lastListener = listener;
}
context_bound = true;
listener.RegisterContext(context);
_contextBound = true;
listener.RegisterContext(_context);
return;
}
BeginReadRequest();
_stream.BeginRead(_buffer, 0, BufferSize, s_onreadCallback, this);
}
private void OnReadInternalException(MemoryStream ms, Exception ex)
private void RemoveConnection()
{
//_logger.ErrorException("Error in HttpConnection.OnReadInternal", ex);
if (ms != null && ms.Length > 0)
SendError();
if (sock != null)
{
CloseSocket();
Unbind();
}
}
void RemoveConnection()
{
if (last_listener == null)
epl.RemoveConnection(this);
if (_lastListener == null)
_epl.RemoveConnection(this);
else
last_listener.RemoveConnection(this);
_lastListener.RemoveConnection(this);
}
enum InputState
private enum InputState
{
RequestLine,
Headers
}
enum LineState
private enum LineState
{
None,
CR,
LF
}
InputState input_state = InputState.RequestLine;
LineState line_state = LineState.None;
int position;
InputState _inputState = InputState.RequestLine;
LineState _lineState = LineState.None;
int _position;
// true -> done processing
// false -> need more input
bool ProcessInput(MemoryStream ms)
private bool ProcessInput(MemoryStream ms)
{
byte[] buffer;
_memoryStreamFactory.TryGetBuffer(ms, out buffer);
byte[] buffer = ms.GetBuffer();
int len = (int)ms.Length;
int used = 0;
string line;
while (true)
{
if (context.HaveError)
if (_context.HaveError)
return true;
if (position >= len)
if (_position >= len)
break;
try
{
line = ReadLine(buffer, position, len - position, ref used);
position += used;
line = ReadLine(buffer, _position, len - _position, ref used);
_position += used;
}
catch
{
context.ErrorMessage = "Bad request";
context.ErrorStatus = 400;
_context.ErrorMessage = "Bad request";
_context.ErrorStatus = 400;
return true;
}
@@ -352,28 +346,28 @@ namespace SocketHttpListener.Net
if (line == "")
{
if (input_state == InputState.RequestLine)
if (_inputState == InputState.RequestLine)
continue;
current_line = null;
_currentLine = null;
ms = null;
return true;
}
if (input_state == InputState.RequestLine)
if (_inputState == InputState.RequestLine)
{
context.Request.SetRequestLine(line);
input_state = InputState.Headers;
_context.Request.SetRequestLine(line);
_inputState = InputState.Headers;
}
else
{
try
{
context.Request.AddHeader(line);
_context.Request.AddHeader(line);
}
catch (Exception e)
{
context.ErrorMessage = e.Message;
context.ErrorStatus = 400;
_context.ErrorMessage = e.Message;
_context.ErrorStatus = 400;
return true;
}
}
@@ -382,42 +376,41 @@ namespace SocketHttpListener.Net
if (used == len)
{
ms.SetLength(0);
position = 0;
_position = 0;
}
return false;
}
string ReadLine(byte[] buffer, int offset, int len, ref int used)
private string ReadLine(byte[] buffer, int offset, int len, ref int used)
{
if (current_line == null)
current_line = new StringBuilder(128);
if (_currentLine == null)
_currentLine = new StringBuilder(128);
int last = offset + len;
used = 0;
for (int i = offset; i < last && line_state != LineState.LF; i++)
for (int i = offset; i < last && _lineState != LineState.LF; i++)
{
used++;
byte b = buffer[i];
if (b == 13)
{
line_state = LineState.CR;
_lineState = LineState.CR;
}
else if (b == 10)
{
line_state = LineState.LF;
_lineState = LineState.LF;
}
else
{
current_line.Append((char)b);
_currentLine.Append((char)b);
}
}
string result = null;
if (line_state == LineState.LF)
if (_lineState == LineState.LF)
{
line_state = LineState.None;
result = current_line.ToString();
current_line.Length = 0;
_lineState = LineState.None;
result = _currentLine.ToString();
_currentLine.Length = 0;
}
return result;
@@ -427,20 +420,18 @@ namespace SocketHttpListener.Net
{
try
{
HttpListenerResponse response = context.Response;
HttpListenerResponse response = _context.Response;
response.StatusCode = status;
response.ContentType = "text/html";
string description = HttpListenerResponse.GetStatusDescription(status);
string str;
if (msg != null)
str = String.Format("<h1>{0} ({1})</h1>", description, msg);
str = string.Format("<h1>{0} ({1})</h1>", description, msg);
else
str = String.Format("<h1>{0}</h1>", description);
str = string.Format("<h1>{0}</h1>", description);
byte[] error = context.Response.ContentEncoding.GetBytes(str);
response.ContentLength64 = error.Length;
response.OutputStream.Write(error, 0, (int)error.Length);
response.Close();
byte[] error = Encoding.Default.GetBytes(str);
response.Close(error, false);
}
catch
{
@@ -450,15 +441,15 @@ namespace SocketHttpListener.Net
public void SendError()
{
SendError(context.ErrorMessage, context.ErrorStatus);
SendError(_context.ErrorMessage, _context.ErrorStatus);
}
void Unbind()
private void Unbind()
{
if (context_bound)
if (_contextBound)
{
epl.UnbindContext(context);
context_bound = false;
_epl.UnbindContext(_context);
_contextBound = false;
}
}
@@ -469,64 +460,60 @@ namespace SocketHttpListener.Net
private void CloseSocket()
{
if (sock == null)
if (_socket == null)
return;
try
{
sock.Close();
}
catch
{
_socket.Close();
}
catch { }
finally
{
sock = null;
_socket = null;
}
RemoveConnection();
}
internal void Close(bool force_close)
internal void Close(bool force)
{
if (sock != null)
if (_socket != null)
{
if (!context.Request.IsWebSocketRequest || force_close)
{
Stream st = GetResponseStream();
if (st != null)
{
st.Dispose();
}
Stream st = GetResponseStream();
if (st != null)
st.Close();
o_stream = null;
}
_responseStream = null;
}
if (sock != null)
if (_socket != null)
{
force_close |= !context.Request.KeepAlive;
if (!force_close)
force_close = (string.Equals(context.Response.Headers["connection"], "close", StringComparison.OrdinalIgnoreCase));
/*
if (!force_close) {
// bool conn_close = (status_code == 400 || status_code == 408 || status_code == 411 ||
// status_code == 413 || status_code == 414 || status_code == 500 ||
// status_code == 503);
force_close |= (context.Request.ProtocolVersion <= HttpVersion.Version10);
}
*/
force |= !_context.Request.KeepAlive;
if (!force)
force = (string.Equals(_context.Response.Headers["connection"], "close", StringComparison.OrdinalIgnoreCase));
if (!force_close && context.Request.FlushInput())
if (!force && _context.Request.FlushInput())
{
reuses++;
if (_chunked && _context.Response.ForceCloseChunked == false)
{
// Don't close. Keep working.
_reuses++;
Unbind();
Init();
BeginReadRequest();
return;
}
_reuses++;
Unbind();
Init();
BeginReadRequest();
return;
}
IAcceptSocket s = sock;
sock = null;
IAcceptSocket s = _socket;
_socket = null;
try
{
if (s != null)

View File

@@ -53,6 +53,11 @@ namespace SocketHttpListener.Net
}
}
public bool ForceCloseChunked
{
get { return false; }
}
public Encoding ContentEncoding
{
get
@@ -335,6 +340,48 @@ namespace SocketHttpListener.Net
context.Connection.Close(force);
}
public void Close(byte[] responseEntity, bool willBlock)
{
//CheckDisposed();
if (responseEntity == null)
{
throw new ArgumentNullException(nameof(responseEntity));
}
//if (_boundaryType != BoundaryType.Chunked)
{
ContentLength64 = responseEntity.Length;
}
if (willBlock)
{
try
{
OutputStream.Write(responseEntity, 0, responseEntity.Length);
}
finally
{
Close(false);
}
}
else
{
OutputStream.BeginWrite(responseEntity, 0, responseEntity.Length, iar =>
{
var thisRef = (HttpListenerResponse)iar.AsyncState;
try
{
thisRef.OutputStream.EndWrite(iar);
}
finally
{
thisRef.Close(false);
}
}, this);
}
}
public void Close()
{
if (disposed)

View File

@@ -325,10 +325,7 @@ namespace SocketHttpListener.Net
}
}
private bool EnableSendFileWithSocket
{
get { return false; }
}
private bool EnableSendFileWithSocket = false;
public Task TransmitFile(string path, long offset, long count, FileShareMode fileShareMode, CancellationToken cancellationToken)
{