ServerBase thread safety (#4577)

* Add guard against ServerBase.Dispose() being called multiple times. Add reset event to avoid Dispose() being called while the ServerLoop is still running.

* remove unused usings

* rework ServerBase to use one collection each for sessions and ports, and make all accesses thread-safe.

* fix Logger call

* use GetSessionObj(int) instead of using _sessions directly

* move _threadStopped check inside "dispose once" test

* - Replace _threadStopped event with attempt to Join() the ending thread (if that isn't the current thread) instead.

- Use the instance-local _selfProcess and (new) _selfThread variables to avoid suggesting that the current KProcess and KThread could change. Per gdkchan, they can't currently, and this old IPC system will be removed before that changes.

- Re-order Dispose() so that the Interlocked _isDisposed check is the last check before disposing, to increase the likelihood that multiple callers will result in one of them succeeding.

* code style suggestions per AcK77

* add infinite wait for thread termination
This commit is contained in:
jhorv 2023-05-21 15:28:51 -04:00 committed by GitHub
parent 5626f2ca1c
commit 21e88f17f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,5 @@
using Ryujinx.Common;
using Ryujinx.Common.Logging;
using Ryujinx.Common.Memory;
using Ryujinx.HLE.HOS.Ipc;
using Ryujinx.HLE.HOS.Kernel;
@ -32,13 +33,14 @@ namespace Ryujinx.HLE.HOS.Services
0x01007FFF
};
private readonly object _handleLock = new();
// The amount of time Dispose() will wait to Join() the thread executing the ServerLoop()
private static readonly TimeSpan ThreadJoinTimeout = TimeSpan.FromSeconds(3);
private readonly KernelContext _context;
private KProcess _selfProcess;
private KThread _selfThread;
private readonly List<int> _sessionHandles = new List<int>();
private readonly List<int> _portHandles = new List<int>();
private readonly ReaderWriterLockSlim _handleLock = new ReaderWriterLockSlim();
private readonly Dictionary<int, IpcService> _sessions = new Dictionary<int, IpcService>();
private readonly Dictionary<int, Func<IpcService>> _ports = new Dictionary<int, Func<IpcService>>();
@ -48,6 +50,8 @@ namespace Ryujinx.HLE.HOS.Services
private readonly MemoryStream _responseDataStream;
private readonly BinaryWriter _responseDataWriter;
private int _isDisposed = 0;
public ManualResetEvent InitDone { get; }
public string Name { get; }
public Func<IpcService> SmObjectFactory { get; }
@ -79,11 +83,20 @@ namespace Ryujinx.HLE.HOS.Services
private void AddPort(int serverPortHandle, Func<IpcService> objectFactory)
{
lock (_handleLock)
bool lockTaken = false;
try
{
_portHandles.Add(serverPortHandle);
lockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
_ports.Add(serverPortHandle, objectFactory);
}
finally
{
if (lockTaken)
{
_handleLock.ExitWriteLock();
}
}
_ports.Add(serverPortHandle, objectFactory);
}
public void AddSessionObj(KServerSession serverSession, IpcService obj)
@ -92,16 +105,62 @@ namespace Ryujinx.HLE.HOS.Services
InitDone.WaitOne();
_selfProcess.HandleTable.GenerateHandle(serverSession, out int serverSessionHandle);
AddSessionObj(serverSessionHandle, obj);
}
public void AddSessionObj(int serverSessionHandle, IpcService obj)
{
lock (_handleLock)
bool lockTaken = false;
try
{
_sessionHandles.Add(serverSessionHandle);
lockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
_sessions.Add(serverSessionHandle, obj);
}
finally
{
if (lockTaken)
{
_handleLock.ExitWriteLock();
}
}
}
private IpcService GetSessionObj(int serverSessionHandle)
{
bool lockTaken = false;
try
{
lockTaken = _handleLock.TryEnterReadLock(Timeout.Infinite);
return _sessions[serverSessionHandle];
}
finally
{
if (lockTaken)
{
_handleLock.ExitReadLock();
}
}
}
private bool RemoveSessionObj(int serverSessionHandle, out IpcService obj)
{
bool lockTaken = false;
try
{
lockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
return _sessions.Remove(serverSessionHandle, out obj);
}
finally
{
if (lockTaken)
{
_handleLock.ExitWriteLock();
}
}
_sessions.Add(serverSessionHandle, obj);
}
private void Main()
@ -112,6 +171,7 @@ namespace Ryujinx.HLE.HOS.Services
private void ServerLoop()
{
_selfProcess = KernelStatic.GetCurrentProcess();
_selfThread = KernelStatic.GetCurrentThread();
if (SmObjectFactory != null)
{
@ -122,8 +182,7 @@ namespace Ryujinx.HLE.HOS.Services
InitDone.Set();
KThread thread = KernelStatic.GetCurrentThread();
ulong messagePtr = thread.TlsAddress;
ulong messagePtr = _selfThread.TlsAddress;
_context.Syscall.SetHeapSize(out ulong heapAddr, 0x200000);
_selfProcess.CpuMemory.Write(messagePtr + 0x0, 0);
@ -134,27 +193,39 @@ namespace Ryujinx.HLE.HOS.Services
while (true)
{
int handleCount;
int portHandleCount;
int handleCount;
int[] handles;
lock (_handleLock)
bool handleLockTaken = false;
try
{
portHandleCount = _portHandles.Count;
handleCount = portHandleCount + _sessionHandles.Count;
handleLockTaken = _handleLock.TryEnterReadLock(Timeout.Infinite);
portHandleCount = _ports.Count;
handleCount = portHandleCount + _sessions.Count;
handles = ArrayPool<int>.Shared.Rent(handleCount);
_portHandles.CopyTo(handles, 0);
_sessionHandles.CopyTo(handles, portHandleCount);
_ports.Keys.CopyTo(handles, 0);
_sessions.Keys.CopyTo(handles, portHandleCount);
}
finally
{
if (handleLockTaken)
{
_handleLock.ExitReadLock();
}
}
// We still need a timeout here to allow the service to pick up and listen new sessions...
var rc = _context.Syscall.ReplyAndReceive(out int signaledIndex, handles.AsSpan(0, handleCount), replyTargetHandle, 1000000L);
thread.HandlePostSyscall();
_selfThread.HandlePostSyscall();
if (!thread.Context.Running)
if (!_selfThread.Context.Running)
{
break;
}
@ -178,9 +249,20 @@ namespace Ryujinx.HLE.HOS.Services
// We got a new connection, accept the session to allow servicing future requests.
if (_context.Syscall.AcceptSession(out int serverSessionHandle, handles[signaledIndex]) == Result.Success)
{
IpcService obj = _ports[handles[signaledIndex]].Invoke();
AddSessionObj(serverSessionHandle, obj);
bool handleWriteLockTaken = false;
try
{
handleWriteLockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
IpcService obj = _ports[handles[signaledIndex]].Invoke();
_sessions.Add(serverSessionHandle, obj);
}
finally
{
if (handleWriteLockTaken)
{
_handleLock.ExitWriteLock();
}
}
}
}
@ -197,11 +279,7 @@ namespace Ryujinx.HLE.HOS.Services
private bool Process(int serverSessionHandle, ulong recvListAddr)
{
KProcess process = KernelStatic.GetCurrentProcess();
KThread thread = KernelStatic.GetCurrentThread();
ulong messagePtr = thread.TlsAddress;
IpcMessage request = ReadRequest(process, messagePtr);
IpcMessage request = ReadRequest();
IpcMessage response = new IpcMessage();
@ -247,15 +325,15 @@ namespace Ryujinx.HLE.HOS.Services
ServiceCtx context = new ServiceCtx(
_context.Device,
process,
process.CpuMemory,
thread,
_selfProcess,
_selfProcess.CpuMemory,
_selfThread,
request,
response,
_requestDataReader,
_responseDataWriter);
_sessions[serverSessionHandle].CallCmifMethod(context);
GetSessionObj(serverSessionHandle).CallCmifMethod(context);
response.RawData = _responseDataStream.ToArray();
}
@ -268,7 +346,7 @@ namespace Ryujinx.HLE.HOS.Services
switch (cmdId)
{
case 0:
FillHipcResponse(response, 0, _sessions[serverSessionHandle].ConvertToDomain());
FillHipcResponse(response, 0, GetSessionObj(serverSessionHandle).ConvertToDomain());
break;
case 3:
@ -278,17 +356,31 @@ namespace Ryujinx.HLE.HOS.Services
// TODO: Whats the difference between IpcDuplicateSession/Ex?
case 2:
case 4:
int unknown = _requestDataReader.ReadInt32();
{
_ = _requestDataReader.ReadInt32();
_context.Syscall.CreateSession(out int dupServerSessionHandle, out int dupClientSessionHandle, false, 0);
_context.Syscall.CreateSession(out int dupServerSessionHandle, out int dupClientSessionHandle, false, 0);
AddSessionObj(dupServerSessionHandle, _sessions[serverSessionHandle]);
bool writeLockTaken = false;
try
{
writeLockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
_sessions[dupServerSessionHandle] = _sessions[serverSessionHandle];
}
finally
{
if (writeLockTaken)
{
_handleLock.ExitWriteLock();
}
}
response.HandleDesc = IpcHandleDesc.MakeMove(dupClientSessionHandle);
response.HandleDesc = IpcHandleDesc.MakeMove(dupClientSessionHandle);
FillHipcResponse(response, 0);
FillHipcResponse(response, 0);
break;
break;
}
default: throw new NotImplementedException(cmdId.ToString());
}
@ -296,13 +388,10 @@ namespace Ryujinx.HLE.HOS.Services
else if (request.Type == IpcMessageType.CmifCloseSession || request.Type == IpcMessageType.TipcCloseSession)
{
_context.Syscall.CloseHandle(serverSessionHandle);
lock (_handleLock)
if (RemoveSessionObj(serverSessionHandle, out var session))
{
_sessionHandles.Remove(serverSessionHandle);
(session as IDisposable)?.Dispose();
}
IpcService service = _sessions[serverSessionHandle];
(service as IDisposable)?.Dispose();
_sessions.Remove(serverSessionHandle);
shouldReply = false;
}
// If the type is past 0xF, we are using TIPC
@ -317,20 +406,20 @@ namespace Ryujinx.HLE.HOS.Services
ServiceCtx context = new ServiceCtx(
_context.Device,
process,
process.CpuMemory,
thread,
_selfProcess,
_selfProcess.CpuMemory,
_selfThread,
request,
response,
_requestDataReader,
_responseDataWriter);
_sessions[serverSessionHandle].CallTipcMethod(context);
GetSessionObj(serverSessionHandle).CallTipcMethod(context);
response.RawData = _responseDataStream.ToArray();
using var responseStream = response.GetStreamTipc();
process.CpuMemory.Write(messagePtr, responseStream.GetReadOnlySequence());
_selfProcess.CpuMemory.Write(_selfThread.TlsAddress, responseStream.GetReadOnlySequence());
}
else
{
@ -339,27 +428,24 @@ namespace Ryujinx.HLE.HOS.Services
if (!isTipcCommunication)
{
using var responseStream = response.GetStream((long)messagePtr, recvListAddr | ((ulong)PointerBufferSize << 48));
process.CpuMemory.Write(messagePtr, responseStream.GetReadOnlySequence());
using var responseStream = response.GetStream((long)_selfThread.TlsAddress, recvListAddr | ((ulong)PointerBufferSize << 48));
_selfProcess.CpuMemory.Write(_selfThread.TlsAddress, responseStream.GetReadOnlySequence());
}
return shouldReply;
}
private static IpcMessage ReadRequest(KProcess process, ulong messagePtr)
private IpcMessage ReadRequest()
{
const int messageSize = 0x100;
byte[] reqData = ArrayPool<byte>.Shared.Rent(messageSize);
using IMemoryOwner<byte> reqDataOwner = ByteMemoryPool.Shared.Rent(messageSize);
Span<byte> reqDataSpan = reqData.AsSpan(0, messageSize);
reqDataSpan.Clear();
Span<byte> reqDataSpan = reqDataOwner.Memory.Span;
process.CpuMemory.Read(messagePtr, reqDataSpan);
_selfProcess.CpuMemory.Read(_selfThread.TlsAddress, reqDataSpan);
IpcMessage request = new IpcMessage(reqDataSpan, (long)messagePtr);
ArrayPool<byte>.Shared.Return(reqData);
IpcMessage request = new IpcMessage(reqDataSpan, (long)_selfThread.TlsAddress);
return request;
}
@ -392,26 +478,35 @@ namespace Ryujinx.HLE.HOS.Services
protected virtual void Dispose(bool disposing)
{
if (disposing)
if (disposing && _selfThread != null)
{
foreach (IpcService service in _sessions.Values)
if (_selfThread.HostThread.ManagedThreadId != Environment.CurrentManagedThreadId && _selfThread.HostThread.Join(ThreadJoinTimeout) == false)
{
if (service is IDisposable disposableObj)
{
disposableObj.Dispose();
}
Logger.Warning?.Print(LogClass.Service, $"The ServerBase thread didn't terminate within {ThreadJoinTimeout:g}, waiting longer.");
service.DestroyAtExit();
_selfThread.HostThread.Join(Timeout.Infinite);
}
_sessions.Clear();
if (Interlocked.Exchange(ref _isDisposed, 1) == 0)
{
foreach (IpcService service in _sessions.Values)
{
(service as IDisposable)?.Dispose();
_requestDataReader.Dispose();
_requestDataStream.Dispose();
_responseDataWriter.Dispose();
_responseDataStream.Dispose();
service.DestroyAtExit();
}
InitDone.Dispose();
_sessions.Clear();
_ports.Clear();
_handleLock.Dispose();
_requestDataReader.Dispose();
_requestDataStream.Dispose();
_responseDataWriter.Dispose();
_responseDataStream.Dispose();
InitDone.Dispose();
}
}
}