diff options
7 files changed, 32 insertions, 43 deletions
diff --git a/Emby.Server.Implementations/HttpServer/WebSocketConnection.cs b/Emby.Server.Implementations/HttpServer/WebSocketConnection.cs index b3bd3421a..b87f1bc22 100644 --- a/Emby.Server.Implementations/HttpServer/WebSocketConnection.cs +++ b/Emby.Server.Implementations/HttpServer/WebSocketConnection.cs @@ -42,17 +42,14 @@ namespace Emby.Server.Implementations.HttpServer /// <param name="logger">The logger.</param> /// <param name="socket">The socket.</param> /// <param name="remoteEndPoint">The remote end point.</param> - /// <param name="query">The query.</param> public WebSocketConnection( ILogger<WebSocketConnection> logger, WebSocket socket, - IPAddress? remoteEndPoint, - IQueryCollection query) + IPAddress? remoteEndPoint) { _logger = logger; _socket = socket; RemoteEndPoint = remoteEndPoint; - QueryString = query; _jsonOptions = JsonDefaults.Options; LastActivityDate = DateTime.Now; @@ -82,12 +79,6 @@ namespace Emby.Server.Implementations.HttpServer public DateTime LastKeepAliveDate { get; set; } /// <summary> - /// Gets the query string. - /// </summary> - /// <value>The query string.</value> - public IQueryCollection QueryString { get; } - - /// <summary> /// Gets the state. /// </summary> /// <value>The state.</value> diff --git a/Emby.Server.Implementations/HttpServer/WebSocketManager.cs b/Emby.Server.Implementations/HttpServer/WebSocketManager.cs index e99876dce..4f7d1c40a 100644 --- a/Emby.Server.Implementations/HttpServer/WebSocketManager.cs +++ b/Emby.Server.Implementations/HttpServer/WebSocketManager.cs @@ -7,6 +7,7 @@ using System.Collections.Generic; using System.Linq; using System.Net.WebSockets; using System.Threading.Tasks; +using MediaBrowser.Common.Extensions; using MediaBrowser.Controller.Net; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; @@ -50,8 +51,7 @@ namespace Emby.Server.Implementations.HttpServer using var connection = new WebSocketConnection( _loggerFactory.CreateLogger<WebSocketConnection>(), webSocket, - context.Connection.RemoteIpAddress, - context.Request.Query) + context.GetNormalizedRemoteIp()) { OnReceive = ProcessWebSocketMessageReceived }; @@ -59,7 +59,7 @@ namespace Emby.Server.Implementations.HttpServer var tasks = new Task[_webSocketListeners.Length]; for (var i = 0; i < _webSocketListeners.Length; ++i) { - tasks[i] = _webSocketListeners[i].ProcessWebSocketConnectedAsync(connection); + tasks[i] = _webSocketListeners[i].ProcessWebSocketConnectedAsync(connection, context); } await Task.WhenAll(tasks).ConfigureAwait(false); diff --git a/Emby.Server.Implementations/Session/SessionWebSocketListener.cs b/Emby.Server.Implementations/Session/SessionWebSocketListener.cs index 2a14a8c7b..a085ee546 100644 --- a/Emby.Server.Implementations/Session/SessionWebSocketListener.cs +++ b/Emby.Server.Implementations/Session/SessionWebSocketListener.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; +using MediaBrowser.Common.Extensions; using MediaBrowser.Controller.Net; using MediaBrowser.Controller.Session; using MediaBrowser.Model.Net; @@ -50,16 +51,10 @@ namespace Emby.Server.Implementations.Session /// </summary> private readonly object _webSocketsLock = new object(); - /// <summary> - /// The _session manager. - /// </summary> private readonly ISessionManager _sessionManager; - - /// <summary> - /// The _logger. - /// </summary> private readonly ILogger<SessionWebSocketListener> _logger; private readonly ILoggerFactory _loggerFactory; + private readonly IAuthorizationContext _authorizationContext; /// <summary> /// The KeepAlive cancellation token. @@ -72,14 +67,17 @@ namespace Emby.Server.Implementations.Session /// <param name="logger">The logger.</param> /// <param name="sessionManager">The session manager.</param> /// <param name="loggerFactory">The logger factory.</param> + /// <param name="authorizationContext">The authorization context.</param> public SessionWebSocketListener( ILogger<SessionWebSocketListener> logger, ISessionManager sessionManager, - ILoggerFactory loggerFactory) + ILoggerFactory loggerFactory, + IAuthorizationContext authorizationContext) { _logger = logger; _sessionManager = sessionManager; _loggerFactory = loggerFactory; + _authorizationContext = authorizationContext; } /// <inheritdoc /> @@ -97,9 +95,9 @@ namespace Emby.Server.Implementations.Session => Task.CompletedTask; /// <inheritdoc /> - public async Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection) + public async Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext) { - var session = await GetSession(connection.QueryString, connection.RemoteEndPoint.ToString()).ConfigureAwait(false); + var session = await GetSession(httpContext, connection.RemoteEndPoint?.ToString()).ConfigureAwait(false); if (session != null) { EnsureController(session, connection); @@ -107,25 +105,28 @@ namespace Emby.Server.Implementations.Session } else { - _logger.LogWarning("Unable to determine session based on query string: {0}", connection.QueryString); + _logger.LogWarning("Unable to determine session based on query string: {0}", httpContext.Request.QueryString); } } - private Task<SessionInfo> GetSession(IQueryCollection queryString, string remoteEndpoint) + private async Task<SessionInfo> GetSession(HttpContext httpContext, string remoteEndpoint) { - if (queryString == null) + var authorizationInfo = await _authorizationContext.GetAuthorizationInfo(httpContext) + .ConfigureAwait(false); + + if (!authorizationInfo.IsAuthenticated) { return null; } - var token = queryString["api_key"]; - if (string.IsNullOrWhiteSpace(token)) + var deviceId = authorizationInfo.DeviceId; + if (httpContext.Request.Query.TryGetValue("deviceId", out var queryDeviceId)) { - return null; + deviceId = queryDeviceId; } - var deviceId = queryString["deviceId"]; - return _sessionManager.GetSessionByAuthenticationToken(token, deviceId, remoteEndpoint); + return await _sessionManager.GetSessionByAuthenticationToken(authorizationInfo.Token, deviceId, remoteEndpoint) + .ConfigureAwait(false); } private void EnsureController(SessionInfo session, IWebSocketConnection connection) diff --git a/MediaBrowser.Controller/Net/BasePeriodicWebSocketListener.cs b/MediaBrowser.Controller/Net/BasePeriodicWebSocketListener.cs index 0813a8e7d..eadc09fd4 100644 --- a/MediaBrowser.Controller/Net/BasePeriodicWebSocketListener.cs +++ b/MediaBrowser.Controller/Net/BasePeriodicWebSocketListener.cs @@ -11,6 +11,7 @@ using System.Threading; using System.Threading.Tasks; using MediaBrowser.Model.Net; using MediaBrowser.Model.Session; +using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; namespace MediaBrowser.Controller.Net @@ -95,7 +96,7 @@ namespace MediaBrowser.Controller.Net } /// <inheritdoc /> - public Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection) => Task.CompletedTask; + public Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext) => Task.CompletedTask; /// <summary> /// Starts sending messages over a web socket. diff --git a/MediaBrowser.Controller/Net/IWebSocketConnection.cs b/MediaBrowser.Controller/Net/IWebSocketConnection.cs index c8c5caf80..2c6483ae2 100644 --- a/MediaBrowser.Controller/Net/IWebSocketConnection.cs +++ b/MediaBrowser.Controller/Net/IWebSocketConnection.cs @@ -30,12 +30,6 @@ namespace MediaBrowser.Controller.Net DateTime LastKeepAliveDate { get; set; } /// <summary> - /// Gets the query string. - /// </summary> - /// <value>The query string.</value> - IQueryCollection QueryString { get; } - - /// <summary> /// Gets or sets the receive action. /// </summary> /// <value>The receive action.</value> diff --git a/MediaBrowser.Controller/Net/IWebSocketListener.cs b/MediaBrowser.Controller/Net/IWebSocketListener.cs index f1a75d518..672bb8cbf 100644 --- a/MediaBrowser.Controller/Net/IWebSocketListener.cs +++ b/MediaBrowser.Controller/Net/IWebSocketListener.cs @@ -1,4 +1,5 @@ using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; namespace MediaBrowser.Controller.Net { @@ -18,7 +19,8 @@ namespace MediaBrowser.Controller.Net /// Processes a new web socket connection. /// </summary> /// <param name="connection">An instance of the <see cref="IWebSocketConnection"/> interface.</param> + /// <param name="httpContext">The current http context.</param> /// <returns>Task.</returns> - Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection); + Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext); } } diff --git a/tests/Jellyfin.Server.Implementations.Tests/HttpServer/WebSocketConnectionTests.cs b/tests/Jellyfin.Server.Implementations.Tests/HttpServer/WebSocketConnectionTests.cs index 1ce2096ea..ef8f7cd90 100644 --- a/tests/Jellyfin.Server.Implementations.Tests/HttpServer/WebSocketConnectionTests.cs +++ b/tests/Jellyfin.Server.Implementations.Tests/HttpServer/WebSocketConnectionTests.cs @@ -13,7 +13,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer [Fact] public void DeserializeWebSocketMessage_SingleSegment_Success() { - var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!); + var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!); var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json"); con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed); Assert.Equal(109, bytesConsumed); @@ -23,7 +23,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer public void DeserializeWebSocketMessage_MultipleSegments_Success() { const int SplitPos = 64; - var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!); + var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!); var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json"); var seg1 = new BufferSegment(new Memory<byte>(bytes, 0, SplitPos)); var seg2 = seg1.Append(new Memory<byte>(bytes, SplitPos, bytes.Length - SplitPos)); @@ -34,7 +34,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer [Fact] public void DeserializeWebSocketMessage_ValidPartial_Success() { - var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!); + var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!); var bytes = File.ReadAllBytes("Test Data/HttpServer/ValidPartial.json"); con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed); Assert.Equal(109, bytesConsumed); @@ -43,7 +43,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer [Fact] public void DeserializeWebSocketMessage_Partial_ThrowJsonException() { - var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!); + var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!); var bytes = File.ReadAllBytes("Test Data/HttpServer/Partial.json"); Assert.Throws<JsonException>(() => con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed)); } |
