Bsd: Implement Select (#4017)

* bsd: Add gdkchan's Select implementation

Co-authored-by: TSRBerry <20988865+tsrberry@users.noreply.github.com>

* bsd: Fix Select() causing a crash with an ArgumentException

.NET Sockets have to be used for the Select() call

* bsd: Make Select more generic

* bsd: Adjust namespaces and remove unused imports

* bsd: Fix NullReferenceException in Select

Co-authored-by: gdkchan <gab.dark.100@gmail.com>
This commit is contained in:
TSRBerry 2022-12-12 14:59:31 +01:00 committed by GitHub
parent 403e67d983
commit ba5c0cf5d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 263 additions and 43 deletions

View File

@ -1,5 +1,8 @@
using System.Collections.Concurrent; using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Numerics;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{ {
@ -41,6 +44,27 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
return null; return null;
} }
public List<IFileDescriptor> RetrieveFileDescriptorsFromMask(ReadOnlySpan<byte> mask)
{
List<IFileDescriptor> fds = new();
for (int i = 0; i < mask.Length; i++)
{
byte current = mask[i];
while (current != 0)
{
int bit = BitOperations.TrailingZeroCount(current);
current &= (byte)~(1 << bit);
int fd = i * 8 + bit;
fds.Add(RetrieveFileDescriptor(fd));
}
}
return fds;
}
public int RegisterFileDescriptor(IFileDescriptor file) public int RegisterFileDescriptor(IFileDescriptor file)
{ {
lock (_lock) lock (_lock)
@ -61,6 +85,16 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
} }
} }
public void BuildMask(List<IFileDescriptor> fds, Span<byte> mask)
{
foreach (IFileDescriptor descriptor in fds)
{
int fd = _fds.IndexOf(descriptor);
mask[fd >> 3] |= (byte)(1 << (fd & 7));
}
}
public int DuplicateFileDescriptor(int fd) public int DuplicateFileDescriptor(int fd)
{ {
IFileDescriptor oldFile = RetrieveFileDescriptor(fd); IFileDescriptor oldFile = RetrieveFileDescriptor(fd);
@ -147,4 +181,4 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
return processContext; return processContext;
} }
} }
} }

View File

@ -1,10 +1,13 @@
using Ryujinx.Common; using Ryujinx.Common;
using Ryujinx.Common.Logging; using Ryujinx.Common.Logging;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using Ryujinx.Memory; using Ryujinx.Memory;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Text; using System.Text;
@ -202,12 +205,122 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
} }
[CommandHipc(5)] [CommandHipc(5)]
// Select(u32 nfds, nn::socket::timeout timeout, buffer<nn::socket::fd_set, 0x21, 0> readfds_in, buffer<nn::socket::fd_set, 0x21, 0> writefds_in, buffer<nn::socket::fd_set, 0x21, 0> errorfds_in) -> (i32 ret, u32 bsd_errno, buffer<nn::socket::fd_set, 0x22, 0> readfds_out, buffer<nn::socket::fd_set, 0x22, 0> writefds_out, buffer<nn::socket::fd_set, 0x22, 0> errorfds_out) // Select(u32 nfds, nn::socket::timeval timeout, buffer<nn::socket::fd_set, 0x21, 0> readfds_in, buffer<nn::socket::fd_set, 0x21, 0> writefds_in, buffer<nn::socket::fd_set, 0x21, 0> errorfds_in)
// -> (i32 ret, u32 bsd_errno, buffer<nn::socket::fd_set, 0x22, 0> readfds_out, buffer<nn::socket::fd_set, 0x22, 0> writefds_out, buffer<nn::socket::fd_set, 0x22, 0> errorfds_out)
public ResultCode Select(ServiceCtx context) public ResultCode Select(ServiceCtx context)
{ {
WriteBsdResult(context, -1, LinuxError.EOPNOTSUPP); int fdsCount = context.RequestData.ReadInt32();
int timeout = context.RequestData.ReadInt32();
Logger.Stub?.PrintStub(LogClass.ServiceBsd); (ulong readFdsInBufferPosition, ulong readFdsInBufferSize) = context.Request.GetBufferType0x21(0);
(ulong writeFdsInBufferPosition, ulong writeFdsInBufferSize) = context.Request.GetBufferType0x21(1);
(ulong errorFdsInBufferPosition, ulong errorFdsInBufferSize) = context.Request.GetBufferType0x21(2);
(ulong readFdsOutBufferPosition, ulong readFdsOutBufferSize) = context.Request.GetBufferType0x22(0);
(ulong writeFdsOutBufferPosition, ulong writeFdsOutBufferSize) = context.Request.GetBufferType0x22(1);
(ulong errorFdsOutBufferPosition, ulong errorFdsOutBufferSize) = context.Request.GetBufferType0x22(2);
List<IFileDescriptor> readFds = _context.RetrieveFileDescriptorsFromMask(context.Memory.GetSpan(readFdsInBufferPosition, (int)readFdsInBufferSize));
List<IFileDescriptor> writeFds = _context.RetrieveFileDescriptorsFromMask(context.Memory.GetSpan(writeFdsInBufferPosition, (int)writeFdsInBufferSize));
List<IFileDescriptor> errorFds = _context.RetrieveFileDescriptorsFromMask(context.Memory.GetSpan(errorFdsInBufferPosition, (int)errorFdsInBufferSize));
int actualFdsCount = readFds.Count + writeFds.Count + errorFds.Count;
if (fdsCount == 0 || actualFdsCount == 0)
{
WriteBsdResult(context, 0);
return ResultCode.Success;
}
PollEvent[] events = new PollEvent[actualFdsCount];
int index = 0;
foreach (IFileDescriptor fd in readFds)
{
events[index] = new PollEvent(new PollEventData { InputEvents = PollEventTypeMask.Input }, fd);
index++;
}
foreach (IFileDescriptor fd in writeFds)
{
events[index] = new PollEvent(new PollEventData { InputEvents = PollEventTypeMask.Output }, fd);
index++;
}
foreach (IFileDescriptor fd in errorFds)
{
events[index] = new PollEvent(new PollEventData { InputEvents = PollEventTypeMask.Error }, fd);
index++;
}
List<PollEvent>[] eventsByPollManager = new List<PollEvent>[_pollManagers.Count];
for (int i = 0; i < eventsByPollManager.Length; i++)
{
eventsByPollManager[i] = new List<PollEvent>();
foreach (PollEvent evnt in events)
{
if (_pollManagers[i].IsCompatible(evnt))
{
eventsByPollManager[i].Add(evnt);
}
}
}
int updatedCount = 0;
for (int i = 0; i < _pollManagers.Count; i++)
{
if (eventsByPollManager[i].Count > 0)
{
_pollManagers[i].Select(eventsByPollManager[i], timeout, out int updatedPollCount);
updatedCount += updatedPollCount;
}
}
readFds.Clear();
writeFds.Clear();
errorFds.Clear();
foreach (PollEvent pollEvent in events)
{
for (int i = 0; i < _pollManagers.Count; i++)
{
if (eventsByPollManager[i].Contains(pollEvent))
{
if (pollEvent.Data.OutputEvents.HasFlag(PollEventTypeMask.Input))
{
readFds.Add(pollEvent.FileDescriptor);
}
if (pollEvent.Data.OutputEvents.HasFlag(PollEventTypeMask.Output))
{
writeFds.Add(pollEvent.FileDescriptor);
}
if (pollEvent.Data.OutputEvents.HasFlag(PollEventTypeMask.Error))
{
errorFds.Add(pollEvent.FileDescriptor);
}
}
}
}
using var readFdsOut = context.Memory.GetWritableRegion(readFdsOutBufferPosition, (int)readFdsOutBufferSize);
using var writeFdsOut = context.Memory.GetWritableRegion(writeFdsOutBufferPosition, (int)writeFdsOutBufferSize);
using var errorFdsOut = context.Memory.GetWritableRegion(errorFdsOutBufferPosition, (int)errorFdsOutBufferSize);
_context.BuildMask(readFds, readFdsOut.Memory.Span);
_context.BuildMask(writeFds, writeFdsOut.Memory.Span);
_context.BuildMask(errorFds, errorFdsOut.Memory.Span);
WriteBsdResult(context, updatedCount);
return ResultCode.Success; return ResultCode.Success;
} }
@ -320,14 +433,14 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
break; break;
} }
// If we are here, that mean nothing was availaible, sleep for 50ms // If we are here, that mean nothing was available, sleep for 50ms
context.Device.System.KernelContext.Syscall.SleepThread(50 * 1000000); context.Device.System.KernelContext.Syscall.SleepThread(50 * 1000000);
} }
while (PerformanceCounter.ElapsedMilliseconds < budgetLeftMilliseconds); while (PerformanceCounter.ElapsedMilliseconds < budgetLeftMilliseconds);
} }
else if (timeout == -1) else if (timeout == -1)
{ {
// FIXME: If we get a timeout of -1 and there is no fds to wait on, this should kill the KProces. (need to check that with re) // FIXME: If we get a timeout of -1 and there is no fds to wait on, this should kill the KProcess. (need to check that with re)
throw new InvalidOperationException(); throw new InvalidOperationException();
} }
else else
@ -998,4 +1111,4 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
return WriteBsdResult(context, newSockFd, errno); return WriteBsdResult(context, newSockFd, errno);
} }
} }
} }

View File

@ -1,4 +1,5 @@
using System; using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{ {
@ -11,4 +12,4 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
LinuxError Write(out int writeSize, ReadOnlySpan<byte> buffer); LinuxError Write(out int writeSize, ReadOnlySpan<byte> buffer);
} }
} }

View File

@ -1,4 +1,5 @@
using System; using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;

View File

@ -1,8 +1,9 @@
using System; using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Threading; using System.Threading;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{ {
class EventFileDescriptor : IFileDescriptor class EventFileDescriptor : IFileDescriptor
{ {
@ -149,4 +150,4 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
} }
} }
} }
} }

View File

@ -1,8 +1,9 @@
using Ryujinx.Common.Logging; using Ryujinx.Common.Logging;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading; using System.Threading;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{ {
class EventFileDescriptorPollManager : IPollManager class EventFileDescriptorPollManager : IPollManager
{ {
@ -109,5 +110,13 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
return LinuxError.SUCCESS; return LinuxError.SUCCESS;
} }
public LinuxError Select(List<PollEvent> events, int timeout, out int updatedCount)
{
// TODO: Implement Select for event file descriptors
updatedCount = 0;
return LinuxError.EOPNOTSUPP;
}
} }
} }

View File

@ -1,4 +1,5 @@
using Ryujinx.Common.Logging; using Ryujinx.Common.Logging;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
@ -6,7 +7,7 @@ using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{ {
class ManagedSocket : ISocket class ManagedSocket : ISocket
{ {

View File

@ -1,8 +1,9 @@
using Ryujinx.Common.Logging; using Ryujinx.Common.Logging;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System.Collections.Generic; using System.Collections.Generic;
using System.Net.Sockets; using System.Net.Sockets;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{ {
class ManagedSocketPollManager : IPollManager class ManagedSocketPollManager : IPollManager
{ {
@ -117,5 +118,60 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
return LinuxError.SUCCESS; return LinuxError.SUCCESS;
} }
public LinuxError Select(List<PollEvent> events, int timeout, out int updatedCount)
{
List<Socket> readEvents = new();
List<Socket> writeEvents = new();
List<Socket> errorEvents = new();
updatedCount = 0;
foreach (PollEvent pollEvent in events)
{
ManagedSocket socket = (ManagedSocket)pollEvent.FileDescriptor;
if (pollEvent.Data.InputEvents.HasFlag(PollEventTypeMask.Input))
{
readEvents.Add(socket.Socket);
}
if (pollEvent.Data.InputEvents.HasFlag(PollEventTypeMask.Output))
{
writeEvents.Add(socket.Socket);
}
if (pollEvent.Data.InputEvents.HasFlag(PollEventTypeMask.Error))
{
errorEvents.Add(socket.Socket);
}
}
Socket.Select(readEvents, writeEvents, errorEvents, timeout);
updatedCount = readEvents.Count + writeEvents.Count + errorEvents.Count;
foreach (PollEvent pollEvent in events)
{
ManagedSocket socket = (ManagedSocket)pollEvent.FileDescriptor;
if (readEvents.Contains(socket.Socket))
{
pollEvent.Data.OutputEvents |= PollEventTypeMask.Input;
}
if (writeEvents.Contains(socket.Socket))
{
pollEvent.Data.OutputEvents |= PollEventTypeMask.Output;
}
if (errorEvents.Contains(socket.Socket))
{
pollEvent.Data.OutputEvents |= PollEventTypeMask.Error;
}
}
return LinuxError.SUCCESS;
}
} }
} }

View File

@ -1,6 +1,6 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{ {
[SuppressMessage("ReSharper", "InconsistentNaming")] [SuppressMessage("ReSharper", "InconsistentNaming")]
enum WsaError enum WsaError

View File

@ -1,7 +1,8 @@
using System.Collections.Generic; using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System.Collections.Generic;
using System.Net.Sockets; using System.Net.Sockets;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{ {
static class WinSockHelper static class WinSockHelper
{ {
@ -162,4 +163,4 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
return table.TryGetValue(option, out name); return table.TryGetValue(option, out name);
} }
} }
} }

View File

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
enum BsdAddressFamily : uint enum BsdAddressFamily : uint
{ {
@ -8,4 +8,4 @@
Unknown = uint.MaxValue Unknown = uint.MaxValue
} }
} }

View File

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
enum BsdIoctl enum BsdIoctl
{ {

View File

@ -1,6 +1,6 @@
using System; using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
class BsdMMsgHdr class BsdMMsgHdr
{ {

View File

@ -1,7 +1,7 @@
using System; using System;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
class BsdMsgHdr class BsdMsgHdr
{ {

View File

@ -3,7 +3,7 @@ using System;
using System.Net; using System.Net;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
[StructLayout(LayoutKind.Sequential, Pack = 1, Size = 0x10)] [StructLayout(LayoutKind.Sequential, Pack = 1, Size = 0x10)]
struct BsdSockAddr struct BsdSockAddr

View File

@ -1,6 +1,6 @@
using System; using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
[Flags] [Flags]
enum BsdSocketCreationFlags enum BsdSocketCreationFlags

View File

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
enum BsdSocketFlags enum BsdSocketFlags
{ {

View File

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
enum BsdSocketOption enum BsdSocketOption
{ {

View File

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
enum BsdSocketShutdownFlags enum BsdSocketShutdownFlags
{ {

View File

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
enum BsdSocketType enum BsdSocketType
{ {

View File

@ -1,6 +1,6 @@
using System; using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
[Flags] [Flags]
enum EventFdFlags : uint enum EventFdFlags : uint

View File

@ -1,11 +1,13 @@
using System.Collections.Generic; using System.Collections.Generic;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
interface IPollManager interface IPollManager
{ {
bool IsCompatible(PollEvent evnt); bool IsCompatible(PollEvent evnt);
LinuxError Poll(List<PollEvent> events, int timeoutMilliseconds, out int updatedCount); LinuxError Poll(List<PollEvent> events, int timeoutMilliseconds, out int updatedCount);
LinuxError Select(List<PollEvent> events, int timeout, out int updatedCount);
} }
} }

View File

@ -1,6 +1,6 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
[SuppressMessage("ReSharper", "InconsistentNaming")] [SuppressMessage("ReSharper", "InconsistentNaming")]
enum LinuxError enum LinuxError

View File

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
class PollEvent class PollEvent
{ {

View File

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
struct PollEventData struct PollEventData
{ {
@ -8,4 +8,4 @@
#pragma warning restore CS0649 #pragma warning restore CS0649
public PollEventTypeMask OutputEvents; public PollEventTypeMask OutputEvents;
} }
} }

View File

@ -1,6 +1,6 @@
using System; using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
[Flags] [Flags]
enum PollEventTypeMask : ushort enum PollEventTypeMask : ushort

View File

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{ {
public struct TimeVal public struct TimeVal
{ {

View File

@ -1,4 +1,5 @@
using Ryujinx.HLE.HOS.Services.Sockets.Bsd; using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
using Ryujinx.HLE.HOS.Services.Ssl.Types; using Ryujinx.HLE.HOS.Services.Ssl.Types;
using System; using System;
using System.IO; using System.IO;