diff --git a/src/Ryujinx.HLE/HOS/Services/ServerBase.cs b/src/Ryujinx.HLE/HOS/Services/ServerBase.cs index b994679a..ff6df8a3 100644 --- a/src/Ryujinx.HLE/HOS/Services/ServerBase.cs +++ b/src/Ryujinx.HLE/HOS/Services/ServerBase.cs @@ -1,4 +1,5 @@ using Ryujinx.Common; +using Ryujinx.Common.Logging; using Ryujinx.Common.Memory; using Ryujinx.HLE.HOS.Ipc; using Ryujinx.HLE.HOS.Kernel; @@ -32,13 +33,14 @@ namespace Ryujinx.HLE.HOS.Services 0x01007FFF }; - private readonly object _handleLock = new(); + // The amount of time Dispose() will wait to Join() the thread executing the ServerLoop() + private static readonly TimeSpan ThreadJoinTimeout = TimeSpan.FromSeconds(3); private readonly KernelContext _context; private KProcess _selfProcess; + private KThread _selfThread; - private readonly List _sessionHandles = new List(); - private readonly List _portHandles = new List(); + private readonly ReaderWriterLockSlim _handleLock = new ReaderWriterLockSlim(); private readonly Dictionary _sessions = new Dictionary(); private readonly Dictionary> _ports = new Dictionary>(); @@ -48,6 +50,8 @@ namespace Ryujinx.HLE.HOS.Services private readonly MemoryStream _responseDataStream; private readonly BinaryWriter _responseDataWriter; + private int _isDisposed = 0; + public ManualResetEvent InitDone { get; } public string Name { get; } public Func SmObjectFactory { get; } @@ -79,11 +83,20 @@ namespace Ryujinx.HLE.HOS.Services private void AddPort(int serverPortHandle, Func objectFactory) { - lock (_handleLock) + bool lockTaken = false; + try { - _portHandles.Add(serverPortHandle); + lockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite); + + _ports.Add(serverPortHandle, objectFactory); + } + finally + { + if (lockTaken) + { + _handleLock.ExitWriteLock(); + } } - _ports.Add(serverPortHandle, objectFactory); } public void AddSessionObj(KServerSession serverSession, IpcService obj) @@ -92,16 +105,62 @@ namespace Ryujinx.HLE.HOS.Services InitDone.WaitOne(); _selfProcess.HandleTable.GenerateHandle(serverSession, out int serverSessionHandle); + AddSessionObj(serverSessionHandle, obj); } public void AddSessionObj(int serverSessionHandle, IpcService obj) { - lock (_handleLock) + bool lockTaken = false; + try { - _sessionHandles.Add(serverSessionHandle); + lockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite); + + _sessions.Add(serverSessionHandle, obj); + } + finally + { + if (lockTaken) + { + _handleLock.ExitWriteLock(); + } + } + } + + private IpcService GetSessionObj(int serverSessionHandle) + { + bool lockTaken = false; + try + { + lockTaken = _handleLock.TryEnterReadLock(Timeout.Infinite); + + return _sessions[serverSessionHandle]; + } + finally + { + if (lockTaken) + { + _handleLock.ExitReadLock(); + } + } + } + + private bool RemoveSessionObj(int serverSessionHandle, out IpcService obj) + { + bool lockTaken = false; + try + { + lockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite); + + return _sessions.Remove(serverSessionHandle, out obj); + } + finally + { + if (lockTaken) + { + _handleLock.ExitWriteLock(); + } } - _sessions.Add(serverSessionHandle, obj); } private void Main() @@ -112,6 +171,7 @@ namespace Ryujinx.HLE.HOS.Services private void ServerLoop() { _selfProcess = KernelStatic.GetCurrentProcess(); + _selfThread = KernelStatic.GetCurrentThread(); if (SmObjectFactory != null) { @@ -122,8 +182,7 @@ namespace Ryujinx.HLE.HOS.Services InitDone.Set(); - KThread thread = KernelStatic.GetCurrentThread(); - ulong messagePtr = thread.TlsAddress; + ulong messagePtr = _selfThread.TlsAddress; _context.Syscall.SetHeapSize(out ulong heapAddr, 0x200000); _selfProcess.CpuMemory.Write(messagePtr + 0x0, 0); @@ -134,27 +193,39 @@ namespace Ryujinx.HLE.HOS.Services while (true) { - int handleCount; int portHandleCount; + int handleCount; int[] handles; - lock (_handleLock) + bool handleLockTaken = false; + try { - portHandleCount = _portHandles.Count; - handleCount = portHandleCount + _sessionHandles.Count; + handleLockTaken = _handleLock.TryEnterReadLock(Timeout.Infinite); + + portHandleCount = _ports.Count; + + handleCount = portHandleCount + _sessions.Count; handles = ArrayPool.Shared.Rent(handleCount); - _portHandles.CopyTo(handles, 0); - _sessionHandles.CopyTo(handles, portHandleCount); + _ports.Keys.CopyTo(handles, 0); + + _sessions.Keys.CopyTo(handles, portHandleCount); + } + finally + { + if (handleLockTaken) + { + _handleLock.ExitReadLock(); + } } // We still need a timeout here to allow the service to pick up and listen new sessions... var rc = _context.Syscall.ReplyAndReceive(out int signaledIndex, handles.AsSpan(0, handleCount), replyTargetHandle, 1000000L); - thread.HandlePostSyscall(); + _selfThread.HandlePostSyscall(); - if (!thread.Context.Running) + if (!_selfThread.Context.Running) { break; } @@ -178,9 +249,20 @@ namespace Ryujinx.HLE.HOS.Services // We got a new connection, accept the session to allow servicing future requests. if (_context.Syscall.AcceptSession(out int serverSessionHandle, handles[signaledIndex]) == Result.Success) { - IpcService obj = _ports[handles[signaledIndex]].Invoke(); - - AddSessionObj(serverSessionHandle, obj); + bool handleWriteLockTaken = false; + try + { + handleWriteLockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite); + IpcService obj = _ports[handles[signaledIndex]].Invoke(); + _sessions.Add(serverSessionHandle, obj); + } + finally + { + if (handleWriteLockTaken) + { + _handleLock.ExitWriteLock(); + } + } } } @@ -197,11 +279,7 @@ namespace Ryujinx.HLE.HOS.Services private bool Process(int serverSessionHandle, ulong recvListAddr) { - KProcess process = KernelStatic.GetCurrentProcess(); - KThread thread = KernelStatic.GetCurrentThread(); - ulong messagePtr = thread.TlsAddress; - - IpcMessage request = ReadRequest(process, messagePtr); + IpcMessage request = ReadRequest(); IpcMessage response = new IpcMessage(); @@ -247,15 +325,15 @@ namespace Ryujinx.HLE.HOS.Services ServiceCtx context = new ServiceCtx( _context.Device, - process, - process.CpuMemory, - thread, + _selfProcess, + _selfProcess.CpuMemory, + _selfThread, request, response, _requestDataReader, _responseDataWriter); - _sessions[serverSessionHandle].CallCmifMethod(context); + GetSessionObj(serverSessionHandle).CallCmifMethod(context); response.RawData = _responseDataStream.ToArray(); } @@ -268,7 +346,7 @@ namespace Ryujinx.HLE.HOS.Services switch (cmdId) { case 0: - FillHipcResponse(response, 0, _sessions[serverSessionHandle].ConvertToDomain()); + FillHipcResponse(response, 0, GetSessionObj(serverSessionHandle).ConvertToDomain()); break; case 3: @@ -278,17 +356,31 @@ namespace Ryujinx.HLE.HOS.Services // TODO: Whats the difference between IpcDuplicateSession/Ex? case 2: case 4: - int unknown = _requestDataReader.ReadInt32(); + { + _ = _requestDataReader.ReadInt32(); - _context.Syscall.CreateSession(out int dupServerSessionHandle, out int dupClientSessionHandle, false, 0); + _context.Syscall.CreateSession(out int dupServerSessionHandle, out int dupClientSessionHandle, false, 0); - AddSessionObj(dupServerSessionHandle, _sessions[serverSessionHandle]); + bool writeLockTaken = false; + try + { + writeLockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite); + _sessions[dupServerSessionHandle] = _sessions[serverSessionHandle]; + } + finally + { + if (writeLockTaken) + { + _handleLock.ExitWriteLock(); + } + } - response.HandleDesc = IpcHandleDesc.MakeMove(dupClientSessionHandle); + response.HandleDesc = IpcHandleDesc.MakeMove(dupClientSessionHandle); - FillHipcResponse(response, 0); + FillHipcResponse(response, 0); - break; + break; + } default: throw new NotImplementedException(cmdId.ToString()); } @@ -296,13 +388,10 @@ namespace Ryujinx.HLE.HOS.Services else if (request.Type == IpcMessageType.CmifCloseSession || request.Type == IpcMessageType.TipcCloseSession) { _context.Syscall.CloseHandle(serverSessionHandle); - lock (_handleLock) + if (RemoveSessionObj(serverSessionHandle, out var session)) { - _sessionHandles.Remove(serverSessionHandle); + (session as IDisposable)?.Dispose(); } - IpcService service = _sessions[serverSessionHandle]; - (service as IDisposable)?.Dispose(); - _sessions.Remove(serverSessionHandle); shouldReply = false; } // If the type is past 0xF, we are using TIPC @@ -317,20 +406,20 @@ namespace Ryujinx.HLE.HOS.Services ServiceCtx context = new ServiceCtx( _context.Device, - process, - process.CpuMemory, - thread, + _selfProcess, + _selfProcess.CpuMemory, + _selfThread, request, response, _requestDataReader, _responseDataWriter); - _sessions[serverSessionHandle].CallTipcMethod(context); + GetSessionObj(serverSessionHandle).CallTipcMethod(context); response.RawData = _responseDataStream.ToArray(); using var responseStream = response.GetStreamTipc(); - process.CpuMemory.Write(messagePtr, responseStream.GetReadOnlySequence()); + _selfProcess.CpuMemory.Write(_selfThread.TlsAddress, responseStream.GetReadOnlySequence()); } else { @@ -339,27 +428,24 @@ namespace Ryujinx.HLE.HOS.Services if (!isTipcCommunication) { - using var responseStream = response.GetStream((long)messagePtr, recvListAddr | ((ulong)PointerBufferSize << 48)); - process.CpuMemory.Write(messagePtr, responseStream.GetReadOnlySequence()); + using var responseStream = response.GetStream((long)_selfThread.TlsAddress, recvListAddr | ((ulong)PointerBufferSize << 48)); + _selfProcess.CpuMemory.Write(_selfThread.TlsAddress, responseStream.GetReadOnlySequence()); } return shouldReply; } - private static IpcMessage ReadRequest(KProcess process, ulong messagePtr) + private IpcMessage ReadRequest() { const int messageSize = 0x100; - byte[] reqData = ArrayPool.Shared.Rent(messageSize); + using IMemoryOwner reqDataOwner = ByteMemoryPool.Shared.Rent(messageSize); - Span reqDataSpan = reqData.AsSpan(0, messageSize); - reqDataSpan.Clear(); + Span reqDataSpan = reqDataOwner.Memory.Span; - process.CpuMemory.Read(messagePtr, reqDataSpan); + _selfProcess.CpuMemory.Read(_selfThread.TlsAddress, reqDataSpan); - IpcMessage request = new IpcMessage(reqDataSpan, (long)messagePtr); - - ArrayPool.Shared.Return(reqData); + IpcMessage request = new IpcMessage(reqDataSpan, (long)_selfThread.TlsAddress); return request; } @@ -392,26 +478,35 @@ namespace Ryujinx.HLE.HOS.Services protected virtual void Dispose(bool disposing) { - if (disposing) + if (disposing && _selfThread != null) { - foreach (IpcService service in _sessions.Values) + if (_selfThread.HostThread.ManagedThreadId != Environment.CurrentManagedThreadId && _selfThread.HostThread.Join(ThreadJoinTimeout) == false) { - if (service is IDisposable disposableObj) - { - disposableObj.Dispose(); - } + Logger.Warning?.Print(LogClass.Service, $"The ServerBase thread didn't terminate within {ThreadJoinTimeout:g}, waiting longer."); - service.DestroyAtExit(); + _selfThread.HostThread.Join(Timeout.Infinite); } - _sessions.Clear(); + if (Interlocked.Exchange(ref _isDisposed, 1) == 0) + { + foreach (IpcService service in _sessions.Values) + { + (service as IDisposable)?.Dispose(); - _requestDataReader.Dispose(); - _requestDataStream.Dispose(); - _responseDataWriter.Dispose(); - _responseDataStream.Dispose(); + service.DestroyAtExit(); + } - InitDone.Dispose(); + _sessions.Clear(); + _ports.Clear(); + _handleLock.Dispose(); + + _requestDataReader.Dispose(); + _requestDataStream.Dispose(); + _responseDataWriter.Dispose(); + _responseDataStream.Dispose(); + + InitDone.Dispose(); + } } }