aboutsummaryrefslogtreecommitdiff
path: root/SocketHttpListener/Net/HttpListenerContext.cs
blob: 58d769f22a05a28678e9cce393907d623c579254 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
using System;
using System.Net;
using System.Security.Principal;
using MediaBrowser.Model.Cryptography;
using MediaBrowser.Model.IO;
using MediaBrowser.Model.Logging;
using MediaBrowser.Model.Text;
using SocketHttpListener.Net.WebSockets;
using SocketHttpListener.Primitives;

namespace SocketHttpListener.Net
{
    /// <summary>
    /// Groups the request, the response and the (optional) authenticated user that
    /// belong to one <see cref="HttpConnection"/>, plus any error recorded for it.
    /// </summary>
    public sealed class HttpListenerContext
    {
        private readonly HttpListenerRequest request;
        private readonly HttpListenerResponse response;
        private IPrincipal user;
        private readonly HttpConnection cnc;
        private string error;
        private int err_status = 400;
        private readonly ICryptoProvider _cryptoProvider;
        private readonly IMemoryStreamFactory _memoryStreamFactory;
        private readonly ITextEncoding _textEncoding;

        /// <summary>
        /// Creates the context for <paramref name="cnc"/> and builds the request and
        /// response objects that are bound to it.
        /// </summary>
        /// <param name="cnc">Connection this context belongs to.</param>
        /// <param name="logger">Logger handed to the response object.</param>
        /// <param name="cryptoProvider">Kept for web socket contexts created by <see cref="AcceptWebSocket"/>.</param>
        /// <param name="memoryStreamFactory">Kept for web socket contexts created by <see cref="AcceptWebSocket"/>.</param>
        /// <param name="textEncoding">Used by the request/response and to decode Basic credentials.</param>
        /// <param name="fileSystem">File system abstraction handed to the response object.</param>
        internal HttpListenerContext(HttpConnection cnc, ILogger logger, ICryptoProvider cryptoProvider, IMemoryStreamFactory memoryStreamFactory, ITextEncoding textEncoding, IFileSystem fileSystem)
        {
            this.cnc = cnc;
            _cryptoProvider = cryptoProvider;
            _memoryStreamFactory = memoryStreamFactory;
            _textEncoding = textEncoding;
            request = new HttpListenerRequest(this, _textEncoding);
            response = new HttpListenerResponse(this, logger, _textEncoding, fileSystem);
        }

        /// <summary>Status code associated with a recorded error; starts out as 400.</summary>
        internal int ErrorStatus
        {
            get { return err_status; }
            set { err_status = value; }
        }

        /// <summary>Error text recorded for this context, or <c>null</c> when there is none.</summary>
        internal string ErrorMessage
        {
            get { return error; }
            set { error = value; }
        }

        /// <summary><c>true</c> once <see cref="ErrorMessage"/> has been set to a non-null value.</summary>
        internal bool HaveError
        {
            get { return (error != null); }
        }

        /// <summary>The connection this context was created for.</summary>
        internal HttpConnection Connection
        {
            get { return cnc; }
        }

        /// <summary>The request object created for this context.</summary>
        public HttpListenerRequest Request
        {
            get { return request; }
        }

        /// <summary>The response object created for this context.</summary>
        public HttpListenerResponse Response
        {
            get { return response; }
        }

        /// <summary>
        /// The principal produced by <see cref="ParseAuthentication"/>, or <c>null</c> when no
        /// usable credentials have been parsed.
        /// </summary>
        public IPrincipal User
        {
            get { return user; }
        }

        /// <summary>
        /// Reads the request's <c>Authorization</c> header and, for the Basic scheme, stores the
        /// resulting principal in <see cref="User"/>. Other schemes and malformed headers are ignored.
        /// </summary>
        /// <param name="expectedSchemes">
        /// Schemes the listener accepts; nothing is parsed when this is exactly
        /// <see cref="AuthenticationSchemes.Anonymous"/>.
        /// </param>
        internal void ParseAuthentication(AuthenticationSchemes expectedSchemes)
        {
            if (expectedSchemes == AuthenticationSchemes.Anonymous)
                return;

            // TODO: Handle NTLM/Digest modes
            string header = request.Headers["Authorization"];
            if (header == null || header.Length < 2)
                return;

            // "<scheme> <credentials>": a header that carries only a scheme yields a single
            // element, so the length must be checked before touching the credentials part.
            string[] authenticationData = header.Split(new char[] { ' ' }, 2);
            if (authenticationData.Length == 2 &&
                string.Equals(authenticationData[0], "basic", StringComparison.OrdinalIgnoreCase))
            {
                user = ParseBasicAuthentication(authenticationData[1]);
            }
            // TODO: throw if malformed -> 400 bad request
        }

        /// <summary>
        /// Decodes Basic credentials of the form <c>DOMAIN\username:password</c> (domain optional)
        /// and wraps them in a principal without roles.
        /// </summary>
        /// <param name="authData">Base64 text that followed the scheme name in the header.</param>
        /// <returns>
        /// A principal carrying the user name (without the domain prefix) and password, or
        /// <c>null</c> when <paramref name="authData"/> cannot be decoded or has no <c>':'</c> separator.
        /// </returns>
        internal IPrincipal ParseBasicAuthentication(string authData)
        {
            if (string.IsNullOrEmpty(authData))
                return null;

            try
            {
                // Basic AUTH Data is a formatted Base64 String
                var authDataBytes = Convert.FromBase64String(authData);
                string credentials = _textEncoding.GetDefaultEncoding().GetString(authDataBytes, 0, authDataBytes.Length);

                // Without a ':' there is no user/password pair to split.
                int colon = credentials.IndexOf(':');
                if (colon < 0)
                    return null;

                // Everything after the first ':' is the password (it may itself contain ':').
                string password = credentials.Substring(colon + 1);
                string account = credentials.Substring(0, colon);

                // Drop an optional "DOMAIN\" prefix, including the backslash itself.
                string userName;
                int backslash = account.IndexOf('\\');
                if (backslash > 0)
                {
                    userName = account.Substring(backslash + 1);
                }
                else
                {
                    userName = account;
                }

                HttpListenerBasicIdentity identity = new HttpListenerBasicIdentity(userName, password);
                // TODO: What are the roles MS sets
                return new GenericPrincipal(identity, new string[0]);
            }
            catch (Exception)
            {
                // Deliberate best-effort: undecodable credentials are treated as "no user"
                // rather than failing the request.
                return null;
            }
        }

        /// <summary>
        /// Creates a web socket context for this HTTP context.
        /// </summary>
        /// <param name="protocol">
        /// Sub-protocol to use, or <c>null</c> for none. When given it must be non-empty and pass
        /// the <c>IsToken()</c> check.
        /// </param>
        /// <returns>A new <see cref="HttpListenerWebSocketContext"/> bound to this context.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="protocol"/> is empty or contains an invalid character.
        /// </exception>
        public HttpListenerWebSocketContext AcceptWebSocket(string protocol)
        {
            if (protocol != null)
            {
                if (protocol.Length == 0)
                    throw new ArgumentException("An empty string.", "protocol");

                if (!protocol.IsToken())
                    throw new ArgumentException("Contains an invalid character.", "protocol");
            }

            return new HttpListenerWebSocketContext(this, protocol, _cryptoProvider, _memoryStreamFactory);
        }
    }

    /// <summary>
    /// Minimal <see cref="IPrincipal"/> implementation: an identity plus a private copy of
    /// the role names it belongs to.
    /// </summary>
    public class GenericPrincipal : IPrincipal
    {
        private readonly IIdentity _identity;
        private readonly string[] _roles;

        /// <summary>
        /// Creates a principal for <paramref name="identity"/>.
        /// </summary>
        /// <param name="identity">Identity to expose; must not be <c>null</c>.</param>
        /// <param name="roles">Role names, copied on construction; may be <c>null</c>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="identity"/> is <c>null</c>.</exception>
        public GenericPrincipal(IIdentity identity, string[] roles)
        {
            if (identity == null)
                throw new ArgumentNullException("identity");

            _identity = identity;

            // Snapshot the array so later changes made by the caller are not observed.
            _roles = roles == null ? null : (string[])roles.Clone();
        }

        /// <summary>The identity supplied to the constructor.</summary>
        public virtual IIdentity Identity
        {
            get { return _identity; }
        }

        /// <summary>
        /// Checks, ignoring case (ordinal), whether <paramref name="role"/> is one of the stored roles.
        /// </summary>
        /// <param name="role">Role name to look for.</param>
        /// <returns>
        /// <c>true</c> on a match; <c>false</c> when <paramref name="role"/> is <c>null</c>, no roles
        /// were supplied, or nothing matches.
        /// </returns>
        public virtual bool IsInRole(string role)
        {
            if (role != null && _roles != null)
            {
                foreach (string candidate in _roles)
                {
                    // Null entries in the role list are skipped rather than compared.
                    if (candidate != null && string.Equals(candidate, role, StringComparison.OrdinalIgnoreCase))
                        return true;
                }
            }

            return false;
        }
    }
}