aboutsummaryrefslogtreecommitdiff
path: root/Emby.Server.Implementations/WebSockets/WebSocketManager.cs
blob: 04c73ecea743faa3b7ddcb5853faa4ba8cd47a93 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using MediaBrowser.Controller.Net;
using MediaBrowser.Model.Net;
using MediaBrowser.Model.Serialization;
using Microsoft.Extensions.Logging;
using UtfUnknown;

namespace Emby.Server.Implementations.WebSockets
{
    /// <summary>
    /// Receives messages from a connected <see cref="WebSocket"/> and dispatches each complete
    /// message to every registered <see cref="IWebSocketHandler"/>.
    /// </summary>
    public class WebSocketManager
    {
        private const int BufferSize = 4096;

        private readonly IWebSocketHandler[] _webSocketHandlers;
        private readonly IJsonSerializer _jsonSerializer;
        private readonly ILogger<WebSocketManager> _logger;

        /// <summary>
        /// Initializes a new instance of the <see cref="WebSocketManager"/> class.
        /// </summary>
        /// <param name="webSocketHandlers">The handlers every received message is dispatched to.</param>
        /// <param name="jsonSerializer">The serializer used to parse incoming messages.</param>
        /// <param name="logger">The logger.</param>
        public WebSocketManager(IWebSocketHandler[] webSocketHandlers, IJsonSerializer jsonSerializer, ILogger<WebSocketManager> logger)
        {
            _webSocketHandlers = webSocketHandlers;
            _jsonSerializer = jsonSerializer;
            _logger = logger;
        }

        /// <summary>
        /// Runs the receive loop for a newly connected web socket until the peer sends a close frame,
        /// the socket leaves the open state, or a handler completes the shared
        /// <see cref="TaskCompletionSource{TResult}"/>; then sends the closing frame if the socket
        /// is still open or has only received the peer's close frame.
        /// </summary>
        /// <param name="webSocket">The connected web socket.</param>
        /// <returns>A task that completes when the receive loop has ended and the socket was closed.</returns>
        public async Task OnWebSocketConnected(WebSocket webSocket)
        {
            // Handlers complete this to stop the receive loop. Continuations run asynchronously
            // so a handler calling SetResult never executes foreign continuations inline.
            var taskCompletionSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

            // The receive buffer is reused across reads: each chunk is copied into 'message'
            // before the next ReceiveAsync call overwrites it.
            var buffer = WebSocket.CreateServerBuffer(BufferSize);
            var message = new List<byte>();
            WebSocketReceiveResult result;

            // Keep listening for incoming messages, otherwise the socket closes automatically
            do
            {
                result = await webSocket.ReceiveAsync(buffer, CancellationToken.None).ConfigureAwait(false);

                if (result.MessageType == WebSocketMessageType.Close)
                {
                    // A close frame carries no application message; stop reading.
                    break;
                }

                // Honor the segment's offset: the backing array is not guaranteed to start at index 0.
                message.AddRange(new ArraySegment<byte>(buffer.Array, buffer.Offset, result.Count));

                if (result.EndOfMessage)
                {
                    await ProcessMessage(message.ToArray(), taskCompletionSource).ConfigureAwait(false);
                    message.Clear();
                }
            }
            while (!taskCompletionSource.Task.IsCompleted
                && webSocket.State == WebSocketState.Open);

            // After the peer's close frame the state is CloseReceived rather than Open; the close
            // handshake still requires our close frame in reply, so close in that state as well.
            if (webSocket.State == WebSocketState.Open
                || webSocket.State == WebSocketState.CloseReceived)
            {
                await webSocket.CloseAsync(
                    result.CloseStatus ?? WebSocketCloseStatus.NormalClosure,
                    result.CloseStatusDescription,
                    CancellationToken.None).ConfigureAwait(false);
            }
        }

        /// <summary>
        /// Decodes one complete message, deserializes it as a <c>WebSocketMessage&lt;object&gt;</c>
        /// and passes it to every registered handler, each on its own <see cref="Task.Run(Func{Task})"/> task.
        /// Failures are logged, never thrown.
        /// </summary>
        /// <param name="messageBytes">The raw bytes of the complete message.</param>
        /// <param name="taskCompletionSource">Shared with the handlers; completing it ends the receive loop.</param>
        /// <returns>A task that completes once every handler has finished with the message.</returns>
        private async Task ProcessMessage(byte[] messageBytes, TaskCompletionSource<bool> taskCompletionSource)
        {
            // RFC 6455 requires text frames to be UTF-8, and UTF-8 decodes plain ASCII identically,
            // so no charset guessing is needed. Decoding as ASCII would turn every non-ASCII
            // character into '?'.
            var message = Encoding.UTF8.GetString(messageBytes, 0, messageBytes.Length);

            // All messages are expected to be valid JSON objects
            if (!message.StartsWith("{", StringComparison.Ordinal))
            {
                _logger.LogDebug("Received web socket message that is not a json structure: {Message}", message);
                return;
            }

            try
            {
                var info = _jsonSerializer.DeserializeFromString<WebSocketMessage<object>>(message);

                _logger.LogDebug("Websocket message received: {MessageType}", info.MessageType);

                var tasks = _webSocketHandlers.Select(handler => Task.Run(async () =>
                {
                    try
                    {
                        // Await the handler so both synchronous and asynchronous failures are
                        // caught and logged here instead of becoming unobserved task exceptions.
                        await handler.ProcessMessage(info, taskCompletionSource).ConfigureAwait(false);
                    }
                    catch (Exception ex)
                    {
                        _logger.LogError(ex, "{HandlerType} failed processing WebSocket message {MessageType}",
                            handler.GetType().Name, info.MessageType ?? string.Empty);
                    }
                }));

                await Task.WhenAll(tasks).ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Error processing web socket message");
            }
        }
    }
}