ryujinx/Ryujinx.Horizon/Sdk/Sf/Hipc/ServerManagerBase.cs
2023-03-04 14:43:08 +01:00

315 lines
9.5 KiB
C#

using Ryujinx.Horizon.Common;
using Ryujinx.Horizon.Sdk.OsTypes;
using Ryujinx.Horizon.Sdk.Sf.Cmif;
using Ryujinx.Horizon.Sdk.Sm;
using System;
namespace Ryujinx.Horizon.Sdk.Sf.Hipc
{
class ServerManagerBase : ServerDomainSessionManager
{
// Service manager API; used by RegisterServerImpl to register named services.
private readonly SmApi _sm;
// Copied from ManagerOptions.CanDeferInvokeRequest in the constructor and never reassigned.
// When true, ProcessForSession keeps a copy of each received message in the session's saved message buffer.
private readonly bool _canDeferInvokeRequest;
// Holders that WaitSignaled waits on.
private readonly MultiWait _multiWait;
// Holders queued by RegisterToWaitList; ProcessWaitList moves them into _multiWait.
private readonly MultiWait _waitList;
// Held for the whole duration of WaitSignaled.
private readonly object _multiWaitSelectionLock;
// Guards _waitList.
private readonly object _waitListLock;
// Signaled by RequestStopProcessing and cleared by ResumeProcessing; makes WaitSignaled return null.
private readonly Event _requestStopEvent;
// Signaled by RegisterToWaitList; WaitSignaled clears it and loops, which re-runs ProcessWaitList.
private readonly Event _notifyEvent;
// Holders that link the two events above into _multiWait (see the constructor).
private readonly MultiWaitHolderBase _requestStopEventHolder;
private readonly MultiWaitHolderBase _notifyEventHolder;
// Tag stored in a holder's UserData so that Process can dispatch to the matching handler.
private enum UserDataTag
{
// The holder is a Server (assigned in RegisterServerImpl).
Server = 1,
// The holder is a ServerSession (assigned in RegisterSessionToWaitList).
Session = 2
}
/// <summary>
/// Creates the manager and links the stop-request and notification event holders into the multi-wait.
/// </summary>
/// <param name="sm">Service manager API, stored for later service registration.</param>
/// <param name="options">Supplies the domain limits forwarded to the base class and the request deferral flag.</param>
public ServerManagerBase(SmApi sm, ManagerOptions options) : base(options.MaxDomainObjects, options.MaxDomains)
{
_sm = sm;
_canDeferInvokeRequest = options.CanDeferInvokeRequest;
_multiWait = new MultiWait();
_waitList = new MultiWait();
_multiWaitSelectionLock = new object();
_waitListLock = new object();
// Both events are manual-clear: they stay signaled until Clear() is called on them.
_requestStopEvent = new Event(EventClearMode.ManualClear);
_notifyEvent = new Event(EventClearMode.ManualClear);
// Link both event holders so that WaitSignaled can be woken up by either event.
_requestStopEventHolder = new MultiWaitHolderOfEvent(_requestStopEvent);
_multiWait.LinkMultiWaitHolder(_requestStopEventHolder);
_notifyEventHolder = new MultiWaitHolderOfEvent(_notifyEvent);
_multiWait.LinkMultiWaitHolder(_notifyEventHolder);
}
/// <summary>
/// Wraps <paramref name="staticObject"/> in a <see cref="ServiceObjectHolder"/> and registers a server
/// on port index 0 for an already existing port handle.
/// </summary>
/// <param name="staticObject">Service object handed to the server as its static object.</param>
/// <param name="portHandle">Handle of the port the server uses.</param>
public void RegisterObjectForServer(IServiceObject staticObject, int portHandle)
{
    ServiceObjectHolder staticHolder = new ServiceObjectHolder(staticObject);

    RegisterServerImpl(0, staticHolder, portHandle);
}
/// <summary>
/// Wraps <paramref name="staticObject"/> in a <see cref="ServiceObjectHolder"/> and registers a server
/// on port index 0 for a service that is registered by name.
/// </summary>
/// <param name="staticObject">Service object handed to the server as its static object.</param>
/// <param name="name">Name of the service to register.</param>
/// <param name="maxSessions">Session limit passed to the service registration.</param>
/// <returns>The failure result of the service registration, or <see cref="Result.Success"/>.</returns>
public Result RegisterObjectForServer(IServiceObject staticObject, ServiceName name, int maxSessions)
{
    ServiceObjectHolder staticHolder = new ServiceObjectHolder(staticObject);

    return RegisterServerImpl(0, staticHolder, name, maxSessions);
}
/// <summary>
/// Registers a server without a static object for an already existing port handle.
/// </summary>
/// <param name="portIndex">Index passed back to <see cref="OnNeedsToAccept"/> when the server is signaled.</param>
/// <param name="portHandle">Handle of the port the server uses.</param>
public void RegisterServer(int portIndex, int portHandle) => RegisterServerImpl(portIndex, null, portHandle);
/// <summary>
/// Registers a server without a static object for a service that is registered by name.
/// </summary>
/// <param name="portIndex">Index passed back to <see cref="OnNeedsToAccept"/> when the server is signaled.</param>
/// <param name="name">Name of the service to register.</param>
/// <param name="maxSessions">Session limit passed to the service registration.</param>
/// <returns>The failure result of the service registration, or <see cref="Result.Success"/>.</returns>
public Result RegisterServer(int portIndex, ServiceName name, int maxSessions) => RegisterServerImpl(portIndex, null, name, maxSessions);
/// <summary>
/// Allocates an unmanaged server (no service name) for an existing port handle and links it into the multi-wait.
/// </summary>
/// <param name="portIndex">Port index stored on the server.</param>
/// <param name="staticHolder">Static object holder for the server; may be null.</param>
/// <param name="portHandle">Handle of the port the server uses.</param>
private void RegisterServerImpl(int portIndex, ServiceObjectHolder staticHolder, int portHandle)
{
    RegisterServerImpl(AllocateServer(portIndex, portHandle, ServiceName.Invalid, managed: false, staticHolder));
}
/// <summary>
/// Registers the named service through the service manager, then allocates a managed server
/// for the returned port handle and links it into the multi-wait.
/// </summary>
/// <param name="portIndex">Port index stored on the server.</param>
/// <param name="staticHolder">Static object holder for the server; may be null.</param>
/// <param name="name">Name of the service to register.</param>
/// <param name="maxSessions">Session limit passed to the service registration.</param>
/// <returns>The failure result of the service registration, or <see cref="Result.Success"/>.</returns>
private Result RegisterServerImpl(int portIndex, ServiceObjectHolder staticHolder, ServiceName name, int maxSessions)
{
    Result registerResult = _sm.RegisterService(out int portHandle, name, maxSessions, isLight: false);

    if (!registerResult.IsFailure)
    {
        // The port handle is only meaningful once the registration went through.
        RegisterServerImpl(AllocateServer(portIndex, portHandle, name, managed: true, staticHolder));

        return Result.Success;
    }

    return registerResult;
}
/// <summary>
/// Tags the server holder as a server and links it into the multi-wait.
/// </summary>
/// <param name="server">Server to start waiting on.</param>
private void RegisterServerImpl(Server server)
{
// The tag is what Process uses to route the holder to ProcessForServer; it is set before the holder is linked.
server.UserData = UserDataTag.Server;
_multiWait.LinkMultiWaitHolder(server);
}
/// <summary>
/// Called by <see cref="ProcessForServer"/> when a signaled server has no static object.
/// This base implementation always throws.
/// </summary>
/// <param name="portIndex">Port index of the signaled server.</param>
/// <param name="server">The signaled server.</param>
/// <returns>Never returns in this implementation.</returns>
/// <exception cref="NotSupportedException">Always thrown by this implementation.</exception>
protected virtual Result OnNeedsToAccept(int portIndex, Server server) => throw new NotSupportedException();
/// <summary>
/// Accepts a session on the server's port, wrapping <paramref name="obj"/> in a new <see cref="ServiceObjectHolder"/>.
/// </summary>
/// <param name="server">Server whose port handle is used for the accept.</param>
/// <param name="obj">Service object for the accepted session.</param>
/// <returns>The result of <c>AcceptSession</c>.</returns>
protected Result AcceptImpl(Server server, IServiceObject obj)
{
    ServiceObjectHolder sessionObject = new ServiceObjectHolder(obj);

    return AcceptSession(server.PortHandle, sessionObject);
}
/// <summary>
/// Waits for and processes requests in a loop until <see cref="WaitAndProcessRequestsImpl"/> returns false.
/// </summary>
public void ServiceRequests()
{
    bool keepProcessing;

    do
    {
        keepProcessing = WaitAndProcessRequestsImpl();
    }
    while (keepProcessing);
}
/// <summary>
/// Waits for a single signaled holder and processes it; the outcome is discarded.
/// </summary>
public void WaitAndProcessRequests() => WaitAndProcessRequestsImpl();
/// <summary>
/// Waits for one signaled holder and processes it.
/// </summary>
/// <returns>
/// <c>false</c> when <see cref="WaitSignaled"/> returns null or a <see cref="ThreadTerminatedException"/> is caught;
/// otherwise the value of <c>HorizonStatic.ThreadContext.Running</c>.
/// </returns>
private bool WaitAndProcessRequestsImpl()
{
    try
    {
        MultiWaitHolder signaledHolder = WaitSignaled();

        if (signaledHolder == null)
        {
            return false;
        }

        // Run the request outside of the assertion argument: the processing is the actual work here,
        // and it must not disappear if the assertion helper is (or ever becomes) conditionally compiled.
        Result processResult = Process(signaledHolder);

        DebugUtil.Assert(processResult.IsSuccess);

        return HorizonStatic.ThreadContext.Running;
    }
    catch (ThreadTerminatedException)
    {
        return false;
    }
}
/// <summary>
/// Blocks until a server or session holder is signaled, unlinks it from the multi-wait and returns it.
/// </summary>
/// <returns>The signaled holder, or null when the stop-request event was the one selected.</returns>
private MultiWaitHolder WaitSignaled()
{
    lock (_multiWaitSelectionLock)
    {
        for (;;)
        {
            // Pick up any holder that was queued since the last wait.
            ProcessWaitList();

            MultiWaitHolder signaled = _multiWait.WaitAny();

            if (signaled == _requestStopEventHolder)
            {
                return null;
            }

            if (signaled == _notifyEventHolder)
            {
                // Only a wake-up; clear it and go around so that the wait list is processed again.
                _notifyEvent.Clear();

                continue;
            }

            signaled.UnlinkFromMultiWaitHolder();

            return signaled;
        }
    }
}
/// <summary>
/// Clears the stop-request event that <see cref="RequestStopProcessing"/> signals.
/// </summary>
public void ResumeProcessing() => _requestStopEvent.Clear();
/// <summary>
/// Signals the stop-request event, which makes <see cref="WaitSignaled"/> return null.
/// </summary>
public void RequestStopProcessing() => _requestStopEvent.Signal();
/// <summary>
/// Marks the session as a session holder with no received request and queues it on the wait list.
/// </summary>
/// <param name="session">Session to queue.</param>
protected override void RegisterSessionToWaitList(ServerSession session)
{
    // The tag routes the holder to ProcessForSession once it is signaled.
    session.UserData = UserDataTag.Session;
    session.HasReceived = false;

    RegisterToWaitList(session);
}
/// <summary>
/// Queues <paramref name="holder"/> on the wait list and signals the notification event.
/// </summary>
/// <param name="holder">Holder to queue.</param>
private void RegisterToWaitList(MultiWaitHolder holder)
{
lock (_waitListLock)
{
_waitList.LinkMultiWaitHolder(holder);
// The notify event's holder is part of _multiWait; when WaitSignaled selects it, it loops
// back to ProcessWaitList, which moves the queued holders into _multiWait.
_notifyEvent.Signal();
}
}
/// <summary>
/// Moves every holder queued on the wait list into the multi-wait, under the wait list lock.
/// </summary>
private void ProcessWaitList()
{
lock (_waitListLock)
{
_multiWait.MoveAllFrom(_waitList);
}
}
/// <summary>
/// Dispatches a signaled holder to the server or session handler, based on the tag stored in its user data.
/// </summary>
/// <param name="holder">Signaled holder.</param>
/// <returns>The result of the selected handler.</returns>
/// <exception cref="NotImplementedException">The tag is neither Server nor Session.</exception>
private Result Process(MultiWaitHolder holder)
{
    UserDataTag tag = (UserDataTag)holder.UserData;

    switch (tag)
    {
        case UserDataTag.Server:
            return ProcessForServer(holder);
        case UserDataTag.Session:
            return ProcessForSession(holder);
        default:
            throw new NotImplementedException(tag.ToString());
    }
}
/// <summary>
/// Handles a signaled server: accepts a session with a clone of the static object when there is one,
/// otherwise defers to <see cref="OnNeedsToAccept"/>. The server is always queued back on the wait list.
/// </summary>
/// <param name="holder">Signaled holder; must be tagged as a server.</param>
/// <returns>The result of the accept.</returns>
private Result ProcessForServer(MultiWaitHolder holder)
{
    DebugUtil.Assert((UserDataTag)holder.UserData == UserDataTag.Server);

    Server signaledServer = (Server)holder;

    try
    {
        return signaledServer.StaticObject != null
            ? AcceptSession(signaledServer.PortHandle, signaledServer.StaticObject.Clone())
            : OnNeedsToAccept(signaledServer.PortIndex, signaledServer);
    }
    finally
    {
        // WaitSignaled unlinked the holder, so it is queued again whatever the accept outcome was.
        RegisterToWaitList(signaledServer);
    }
}
/// <summary>
/// Handles a signaled session: receives the request into the thread's TLS message buffer unless it
/// was already received, then processes it.
/// </summary>
/// <param name="holder">Signaled holder; must be tagged as a session.</param>
/// <returns>
/// The failure result of the receive or of the processing, or <see cref="Result.Success"/>.
/// With deferral enabled, a processing failure matching <c>SfResult.Invalidated</c> is reported as success.
/// </returns>
private Result ProcessForSession(MultiWaitHolder holder)
{
    DebugUtil.Assert((UserDataTag)holder.UserData == UserDataTag.Session);

    ServerSession session = (ServerSession)holder;

    using var tlsMessage = HorizonStatic.AddressSpace.GetWritableRegion(HorizonStatic.ThreadContext.TlsAddress, Api.TlsMessageBufferSize);

    if (!_canDeferInvokeRequest)
    {
        if (!session.HasReceived)
        {
            Result receiveResult = ReceiveRequest(session, tlsMessage.Memory.Span);

            if (receiveResult.IsFailure)
            {
                return receiveResult;
            }

            session.HasReceived = true;
        }

        Result processResult = ProcessRequest(session, tlsMessage.Memory.Span);

        if (!processResult.IsFailure)
        {
            return Result.Success;
        }

        // Those results are not valid because the service does not support deferral.
        if (SfResult.RequestDeferred(processResult) || SfResult.Invalidated(processResult))
        {
            processResult.AbortOnFailure();
        }

        return processResult;
    }
    else
    {
        // A deferred request is processed again later, so the message is kept on the session's saved buffer.
        using var savedMessage = HorizonStatic.AddressSpace.GetWritableRegion(session.SavedMessage.Address, (int)session.SavedMessage.Size);

        DebugUtil.Assert(tlsMessage.Memory.Length == savedMessage.Memory.Length);

        if (session.HasReceived)
        {
            // Already received earlier: restore the saved message into the TLS buffer.
            savedMessage.Memory.Span.CopyTo(tlsMessage.Memory.Span);
        }
        else
        {
            Result receiveResult = ReceiveRequest(session, tlsMessage.Memory.Span);

            if (receiveResult.IsFailure)
            {
                return receiveResult;
            }

            session.HasReceived = true;

            tlsMessage.Memory.Span.CopyTo(savedMessage.Memory.Span);
        }

        Result processResult = ProcessRequest(session, tlsMessage.Memory.Span);

        if (processResult.IsFailure && !SfResult.Invalidated(processResult))
        {
            return processResult;
        }

        return Result.Success;
    }
}
}
}