From 8544b1445b33381fca63714249ac36598c413004 Mon Sep 17 00:00:00 2001 From: gdkchan Date: Wed, 29 Dec 2021 11:04:38 -0300 Subject: [PATCH] Improve SocketOption handling (#2946) --- .../HOS/Services/Sockets/Bsd/IClient.cs | 180 ++++++++++-------- .../Sockets/Bsd/Types/BsdSocketFlags.cs | 2 - .../Sockets/Bsd/Types/BsdSocketOption.cs | 119 ++++++++++++ 3 files changed, 221 insertions(+), 80 deletions(-) create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketOption.cs diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs index 43c7ee7d..76f80f92 100644 --- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs @@ -13,7 +13,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd [Service("bsd:u", false)] class IClient : IpcService { - private static Dictionary _errorMap = new Dictionary + private static readonly Dictionary _errorMap = new() { // WSAEINTR {WsaError.WSAEINTR, LinuxError.EINTR}, @@ -97,6 +97,50 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd {0, 0} }; + private static readonly Dictionary _soSocketOptionMap = new() + { + { BsdSocketOption.SoDebug, SocketOptionName.Debug }, + { BsdSocketOption.SoReuseAddr, SocketOptionName.ReuseAddress }, + { BsdSocketOption.SoKeepAlive, SocketOptionName.KeepAlive }, + { BsdSocketOption.SoDontRoute, SocketOptionName.DontRoute }, + { BsdSocketOption.SoBroadcast, SocketOptionName.Broadcast }, + { BsdSocketOption.SoUseLoopBack, SocketOptionName.UseLoopback }, + { BsdSocketOption.SoLinger, SocketOptionName.Linger }, + { BsdSocketOption.SoOobInline, SocketOptionName.OutOfBandInline }, + { BsdSocketOption.SoReusePort, SocketOptionName.ReuseAddress }, + { BsdSocketOption.SoSndBuf, SocketOptionName.SendBuffer }, + { BsdSocketOption.SoRcvBuf, SocketOptionName.ReceiveBuffer }, + { BsdSocketOption.SoSndLoWat, SocketOptionName.SendLowWater }, + { BsdSocketOption.SoRcvLoWat, SocketOptionName.ReceiveLowWater }, + { BsdSocketOption.SoSndTimeo, SocketOptionName.SendTimeout }, + { BsdSocketOption.SoRcvTimeo, SocketOptionName.ReceiveTimeout }, + { BsdSocketOption.SoError, SocketOptionName.Error }, + { BsdSocketOption.SoType, SocketOptionName.Type } + }; + + private static readonly Dictionary _ipSocketOptionMap = new() + { + { BsdSocketOption.IpOptions, SocketOptionName.IPOptions }, + { BsdSocketOption.IpHdrIncl, SocketOptionName.HeaderIncluded }, + { BsdSocketOption.IpTtl, SocketOptionName.IpTimeToLive }, + { BsdSocketOption.IpMulticastIf, SocketOptionName.MulticastInterface }, + { BsdSocketOption.IpMulticastTtl, SocketOptionName.MulticastTimeToLive }, + { BsdSocketOption.IpMulticastLoop, SocketOptionName.MulticastLoopback }, + { BsdSocketOption.IpAddMembership, SocketOptionName.AddMembership }, + { BsdSocketOption.IpDropMembership, SocketOptionName.DropMembership }, + { BsdSocketOption.IpDontFrag, SocketOptionName.DontFragment }, + { BsdSocketOption.IpAddSourceMembership, SocketOptionName.AddSourceMembership }, + { BsdSocketOption.IpDropSourceMembership, SocketOptionName.DropSourceMembership } + }; + + private static readonly Dictionary _tcpSocketOptionMap = new() + { + { BsdSocketOption.TcpNoDelay, SocketOptionName.NoDelay }, + { BsdSocketOption.TcpKeepIdle, SocketOptionName.TcpKeepAliveTime }, + { BsdSocketOption.TcpKeepIntvl, SocketOptionName.TcpKeepAliveInterval }, + { BsdSocketOption.TcpKeepCnt, SocketOptionName.TcpKeepAliveRetryCount } + }; + private bool _isPrivileged; private List _sockets = new List(); @@ -118,13 +162,6 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd private static SocketFlags ConvertBsdSocketFlags(BsdSocketFlags bsdSocketFlags) { - BsdSocketFlags SupportedFlags = - BsdSocketFlags.Oob | - BsdSocketFlags.Peek | - BsdSocketFlags.DontRoute | - BsdSocketFlags.Trunc | - BsdSocketFlags.CTrunc; - SocketFlags socketFlags = SocketFlags.None; if (bsdSocketFlags.HasFlag(BsdSocketFlags.Oob)) @@ -166,6 +203,25 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd return socketFlags; } + private static bool TryConvertSocketOption(BsdSocketOption option, SocketOptionLevel level, out SocketOptionName name) + { + var table = level switch + { + SocketOptionLevel.Socket => _soSocketOptionMap, + SocketOptionLevel.IP => _ipSocketOptionMap, + SocketOptionLevel.Tcp => _tcpSocketOptionMap, + _ => null + }; + + if (table == null) + { + name = default; + return false; + } + + return table.TryGetValue(option, out name); + } + private ResultCode WriteWinSock2Error(ServiceCtx context, WsaError errorCode) { return WriteBsdResult(context, -1, ConvertError(errorCode)); @@ -820,9 +876,9 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd // GetSockOpt(u32 socket, u32 level, u32 option_name) -> (i32 ret, u32 bsd_errno, u32, buffer) public ResultCode GetSockOpt(ServiceCtx context) { - int socketFd = context.RequestData.ReadInt32(); - SocketOptionLevel level = (SocketOptionLevel)context.RequestData.ReadInt32(); - SocketOptionName optionName = (SocketOptionName)context.RequestData.ReadInt32(); + int socketFd = context.RequestData.ReadInt32(); + SocketOptionLevel level = (SocketOptionLevel)context.RequestData.ReadInt32(); + BsdSocketOption option = (BsdSocketOption)context.RequestData.ReadInt32(); (ulong bufferPosition, ulong bufferSize) = context.Request.GetBufferType0x22(); @@ -831,7 +887,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd if (socket != null) { - errno = HandleGetSocketOption(context, socket, optionName, level, bufferPosition, bufferSize); + errno = HandleGetSocketOption(context, socket, option, level, bufferPosition, bufferSize); } return WriteBsdResult(context, 0, errno); @@ -936,45 +992,26 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd private static LinuxError HandleGetSocketOption( ServiceCtx context, BsdSocket socket, - SocketOptionName optionName, + BsdSocketOption option, SocketOptionLevel level, ulong optionValuePosition, ulong optionValueSize) { try { + if (!TryConvertSocketOption(option, level, out SocketOptionName optionName)) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported GetSockOpt Option: {option} Level: {level}"); + + return LinuxError.EOPNOTSUPP; + } + byte[] optionValue = new byte[optionValueSize]; - switch (optionName) - { - case SocketOptionName.Broadcast: - case SocketOptionName.DontLinger: - case SocketOptionName.Debug: - case SocketOptionName.Error: - case SocketOptionName.KeepAlive: - case SocketOptionName.OutOfBandInline: - case SocketOptionName.ReceiveBuffer: - case SocketOptionName.ReceiveTimeout: - case SocketOptionName.SendBuffer: - case SocketOptionName.SendTimeout: - case SocketOptionName.Type: - case SocketOptionName.Linger: - socket.Handle.GetSocketOption(level, optionName, optionValue); - context.Memory.Write(optionValuePosition, optionValue); + socket.Handle.GetSocketOption(level, optionName, optionValue); + context.Memory.Write(optionValuePosition, optionValue); - return LinuxError.SUCCESS; - - case (SocketOptionName)0x200: - socket.Handle.GetSocketOption(level, SocketOptionName.ReuseAddress, optionValue); - context.Memory.Write(optionValuePosition, optionValue); - - return LinuxError.SUCCESS; - - default: - Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported GetSockOpt OptionName: {optionName}"); - - return LinuxError.EOPNOTSUPP; - } + return LinuxError.SUCCESS; } catch (SocketException exception) { @@ -985,47 +1022,34 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd private static LinuxError HandleSetSocketOption( ServiceCtx context, BsdSocket socket, - SocketOptionName optionName, + BsdSocketOption option, SocketOptionLevel level, ulong optionValuePosition, ulong optionValueSize) { try { - switch (optionName) + if (!TryConvertSocketOption(option, level, out SocketOptionName optionName)) { - case SocketOptionName.Broadcast: - case SocketOptionName.DontLinger: - case SocketOptionName.Debug: - case SocketOptionName.Error: - case SocketOptionName.KeepAlive: - case SocketOptionName.OutOfBandInline: - case SocketOptionName.ReceiveBuffer: - case SocketOptionName.ReceiveTimeout: - case SocketOptionName.SendBuffer: - case SocketOptionName.SendTimeout: - case SocketOptionName.Type: - case SocketOptionName.ReuseAddress: - socket.Handle.SetSocketOption(level, optionName, context.Memory.Read((ulong)optionValuePosition)); + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported SetSockOpt Option: {option} Level: {level}"); - return LinuxError.SUCCESS; - - case (SocketOptionName)0x200: - socket.Handle.SetSocketOption(level, SocketOptionName.ReuseAddress, context.Memory.Read((ulong)optionValuePosition)); - - return LinuxError.SUCCESS; - - case SocketOptionName.Linger: - socket.Handle.SetSocketOption(level, SocketOptionName.Linger, - new LingerOption(context.Memory.Read((ulong)optionValuePosition) != 0, context.Memory.Read((ulong)optionValuePosition + 4))); - - return LinuxError.SUCCESS; - - default: - Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported SetSockOpt OptionName: {optionName}"); - - return LinuxError.EOPNOTSUPP; + return LinuxError.EOPNOTSUPP; } + + int value = context.Memory.Read((ulong)optionValuePosition); + + if (option == BsdSocketOption.SoLinger) + { + int value2 = context.Memory.Read((ulong)optionValuePosition + 4); + + socket.Handle.SetSocketOption(level, SocketOptionName.Linger, new LingerOption(value != 0, value2)); + } + else + { + socket.Handle.SetSocketOption(level, optionName, value); + } + + return LinuxError.SUCCESS; } catch (SocketException exception) { @@ -1037,9 +1061,9 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd // SetSockOpt(u32 socket, u32 level, u32 option_name, buffer option_value) -> (i32 ret, u32 bsd_errno) public ResultCode SetSockOpt(ServiceCtx context) { - int socketFd = context.RequestData.ReadInt32(); - SocketOptionLevel level = (SocketOptionLevel)context.RequestData.ReadInt32(); - SocketOptionName optionName = (SocketOptionName)context.RequestData.ReadInt32(); + int socketFd = context.RequestData.ReadInt32(); + SocketOptionLevel level = (SocketOptionLevel)context.RequestData.ReadInt32(); + BsdSocketOption option = (BsdSocketOption)context.RequestData.ReadInt32(); (ulong bufferPos, ulong bufferSize) = context.Request.GetBufferType0x21(); @@ -1048,7 +1072,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd if (socket != null) { - errno = HandleSetSocketOption(context, socket, optionName, level, bufferPos, bufferSize); + errno = HandleSetSocketOption(context, socket, option, level, bufferPos, bufferSize); } return WriteBsdResult(context, 0, errno); diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketFlags.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketFlags.cs index 4dc56356..ca464c09 100644 --- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketFlags.cs +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketFlags.cs @@ -1,5 +1,3 @@ -using System.Net.Sockets; - namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd { enum BsdSocketFlags diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketOption.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketOption.cs new file mode 100644 index 00000000..726e4111 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketOption.cs @@ -0,0 +1,119 @@ +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + enum BsdSocketOption + { + SoDebug = 0x1, + SoAcceptConn = 0x2, + SoReuseAddr = 0x4, + SoKeepAlive = 0x8, + SoDontRoute = 0x10, + SoBroadcast = 0x20, + SoUseLoopBack = 0x40, + SoLinger = 0x80, + SoOobInline = 0x100, + SoReusePort = 0x200, + SoTimestamp = 0x400, + SoNoSigpipe = 0x800, + SoAcceptFilter = 0x1000, + SoBinTime = 0x2000, + SoNoOffload = 0x4000, + SoNoDdp = 0x8000, + SoReusePortLb = 0x10000, + SoRError = 0x20000, + + SoSndBuf = 0x1001, + SoRcvBuf = 0x1002, + SoSndLoWat = 0x1003, + SoRcvLoWat = 0x1004, + SoSndTimeo = 0x1005, + SoRcvTimeo = 0x1006, + SoError = 0x1007, + SoType = 0x1008, + SoLabel = 0x1009, + SoPeerLabel = 0x1010, + SoListenQLimit = 0x1011, + SoListenQLen = 0x1012, + SoListenIncQLen = 0x1013, + SoSetFib = 0x1014, + SoUserCookie = 0x1015, + SoProtocol = 0x1016, + SoTsClock = 0x1017, + SoMaxPacingRate = 0x1018, + SoDomain = 0x1019, + + IpOptions = 1, + IpHdrIncl = 2, + IpTos = 3, + IpTtl = 4, + IpRecvOpts = 5, + IpRecvRetOpts = 6, + IpRecvDstAddr = 7, + IpRetOpts = 8, + IpMulticastIf = 9, + IpMulticastTtl = 10, + IpMulticastLoop = 11, + IpAddMembership = 12, + IpDropMembership = 13, + IpMulticastVif = 14, + IpRsvpOn = 15, + IpRsvpOff = 16, + IpRsvpVifOn = 17, + IpRsvpVifOff = 18, + IpPortRange = 19, + IpRecvIf = 20, + IpIpsecPolicy = 21, + IpOnesBcast = 23, + IpBindany = 24, + IpBindMulti = 25, + IpRssListenBucket = 26, + IpOrigDstAddr = 27, + + IpFwTableAdd = 40, + IpFwTableDel = 41, + IpFwTableFlush = 42, + IpFwTableGetSize = 43, + IpFwTableList = 44, + + IpFw3 = 48, + IpDummyNet3 = 49, + + IpFwAdd = 50, + IpFwDel = 51, + IpFwFlush = 52, + IpFwZero = 53, + IpFwGet = 54, + IpFwResetLog = 55, + + IpFwNatCfg = 56, + IpFwNatDel = 57, + IpFwNatGetConfig = 58, + IpFwNatGetLog = 59, + + IpDummyNetConfigure = 60, + IpDummyNetDel = 61, + IpDummyNetFlush = 62, + IpDummyNetGet = 64, + + IpRecvTtl = 65, + IpMinTtl = 66, + IpDontFrag = 67, + IpRecvTos = 68, + + IpAddSourceMembership = 70, + IpDropSourceMembership = 71, + IpBlockSource = 72, + IpUnblockSource = 73, + + TcpNoDelay = 1, + TcpMaxSeg = 2, + TcpNoPush = 4, + TcpNoOpt = 8, + TcpMd5Sig = 16, + TcpInfo = 32, + TcpCongestion = 64, + TcpKeepInit = 128, + TcpKeepIdle = 256, + TcpKeepIntvl = 512, + TcpKeepCnt = 1024 + } +} \ No newline at end of file