From edfd4d70c0f38d41c6ebb31508127b14727017bd Mon Sep 17 00:00:00 2001 From: Logan Stromberg Date: Tue, 21 Feb 2023 02:44:57 -0800 Subject: [PATCH] Use SIMD acceleration for audio upsampler (#4410) * Use SIMD acceleration for audio upsampler filter kernel for a moderate speedup * Address formatting. Implement AVX2 fast path for high quality resampling in ResamplerHelper * now really, are we really getting the benefit of inlining 50+ line methods? * adding unit tests for resampler + upsampler. The upsampler ones fail for some reason * Fixing upsampler test. Apparently this algo only works at specific ratios --------- Co-authored-by: Logan Stromberg --- Ryujinx.Audio/Renderer/Dsp/ResamplerHelper.cs | 183 ++++++++++-------- Ryujinx.Audio/Renderer/Dsp/UpsamplerHelper.cs | 23 ++- .../Audio/Renderer/Dsp/ResamplerTests.cs | 93 +++++++++ .../Audio/Renderer/Dsp/UpsamplerTests.cs | 64 ++++++ 4 files changed, 279 insertions(+), 84 deletions(-) create mode 100644 Ryujinx.Tests/Audio/Renderer/Dsp/ResamplerTests.cs create mode 100644 Ryujinx.Tests/Audio/Renderer/Dsp/UpsamplerTests.cs diff --git a/Ryujinx.Audio/Renderer/Dsp/ResamplerHelper.cs b/Ryujinx.Audio/Renderer/Dsp/ResamplerHelper.cs index b46a33fe..7873c4d2 100644 --- a/Ryujinx.Audio/Renderer/Dsp/ResamplerHelper.cs +++ b/Ryujinx.Audio/Renderer/Dsp/ResamplerHelper.cs @@ -1,5 +1,6 @@ using System; using System.Linq; +using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.X86; @@ -380,7 +381,6 @@ namespace Ryujinx.Audio.Renderer.Dsp return _normalCurveLut2F; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] private unsafe static void ResampleDefaultQuality(Span outputBuffer, ReadOnlySpan inputBuffer, float ratio, ref float fraction, int sampleCount, bool needPitch) { ReadOnlySpan parameters = GetDefaultParameter(ratio); @@ -394,35 +394,33 @@ namespace Ryujinx.Audio.Renderer.Dsp if (ratio == 1f) { fixed (short* pInput = inputBuffer) + fixed (float* pOutput = outputBuffer, pParameters = parameters) { - fixed (float* pOutput = outputBuffer, pParameters = parameters) + Vector128 parameter = Sse.LoadVector128(pParameters); + + for (; i < (sampleCount & ~3); i += 4) { - Vector128 parameter = Sse.LoadVector128(pParameters); + Vector128 intInput0 = Sse41.ConvertToVector128Int32(pInput + (uint)i); + Vector128 intInput1 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 1); + Vector128 intInput2 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 2); + Vector128 intInput3 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 3); - for (; i < (sampleCount & ~3); i += 4) - { - Vector128 intInput0 = Sse41.ConvertToVector128Int32(pInput + (uint)i); - Vector128 intInput1 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 1); - Vector128 intInput2 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 2); - Vector128 intInput3 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 3); + Vector128 input0 = Sse2.ConvertToVector128Single(intInput0); + Vector128 input1 = Sse2.ConvertToVector128Single(intInput1); + Vector128 input2 = Sse2.ConvertToVector128Single(intInput2); + Vector128 input3 = Sse2.ConvertToVector128Single(intInput3); - Vector128 input0 = Sse2.ConvertToVector128Single(intInput0); - Vector128 input1 = Sse2.ConvertToVector128Single(intInput1); - Vector128 input2 = Sse2.ConvertToVector128Single(intInput2); - Vector128 input3 = Sse2.ConvertToVector128Single(intInput3); + Vector128 mix0 = Sse.Multiply(input0, parameter); + Vector128 mix1 = Sse.Multiply(input1, parameter); + Vector128 mix2 = Sse.Multiply(input2, parameter); + Vector128 mix3 = Sse.Multiply(input3, parameter); - Vector128 mix0 = Sse.Multiply(input0, parameter); - Vector128 mix1 = Sse.Multiply(input1, parameter); - Vector128 mix2 = Sse.Multiply(input2, parameter); - Vector128 mix3 = Sse.Multiply(input3, parameter); + Vector128 mix01 = Sse3.HorizontalAdd(mix0, mix1); + Vector128 mix23 = Sse3.HorizontalAdd(mix2, mix3); - Vector128 mix01 = Sse3.HorizontalAdd(mix0, mix1); - Vector128 mix23 = Sse3.HorizontalAdd(mix2, mix3); + Vector128 mix0123 = Sse3.HorizontalAdd(mix01, mix23); - Vector128 mix0123 = Sse3.HorizontalAdd(mix01, mix23); - - Sse.Store(pOutput + (uint)i, Sse41.RoundToNearestInteger(mix0123)); - } + Sse.Store(pOutput + (uint)i, Sse41.RoundToNearestInteger(mix0123)); } } @@ -431,62 +429,60 @@ namespace Ryujinx.Audio.Renderer.Dsp else { fixed (short* pInput = inputBuffer) + fixed (float* pOutput = outputBuffer, pParameters = parameters) { - fixed (float* pOutput = outputBuffer, pParameters = parameters) + for (; i < (sampleCount & ~3); i += 4) { - for (; i < (sampleCount & ~3); i += 4) - { - uint baseIndex0 = (uint)(fraction * 128) * 4; - uint inputIndex0 = (uint)inputBufferIndex; + uint baseIndex0 = (uint)(fraction * 128) * 4; + uint inputIndex0 = (uint)inputBufferIndex; - fraction += ratio; + fraction += ratio; - uint baseIndex1 = ((uint)(fraction * 128) & 127) * 4; - uint inputIndex1 = (uint)inputBufferIndex + (uint)fraction; + uint baseIndex1 = ((uint)(fraction * 128) & 127) * 4; + uint inputIndex1 = (uint)inputBufferIndex + (uint)fraction; - fraction += ratio; + fraction += ratio; - uint baseIndex2 = ((uint)(fraction * 128) & 127) * 4; - uint inputIndex2 = (uint)inputBufferIndex + (uint)fraction; + uint baseIndex2 = ((uint)(fraction * 128) & 127) * 4; + uint inputIndex2 = (uint)inputBufferIndex + (uint)fraction; - fraction += ratio; + fraction += ratio; - uint baseIndex3 = ((uint)(fraction * 128) & 127) * 4; - uint inputIndex3 = (uint)inputBufferIndex + (uint)fraction; + uint baseIndex3 = ((uint)(fraction * 128) & 127) * 4; + uint inputIndex3 = (uint)inputBufferIndex + (uint)fraction; - fraction += ratio; - inputBufferIndex += (int)fraction; + fraction += ratio; + inputBufferIndex += (int)fraction; - // Only keep lower part (safe as fraction isn't supposed to be negative) - fraction -= (int)fraction; + // Only keep lower part (safe as fraction isn't supposed to be negative) + fraction -= (int)fraction; - Vector128 parameter0 = Sse.LoadVector128(pParameters + baseIndex0); - Vector128 parameter1 = Sse.LoadVector128(pParameters + baseIndex1); - Vector128 parameter2 = Sse.LoadVector128(pParameters + baseIndex2); - Vector128 parameter3 = Sse.LoadVector128(pParameters + baseIndex3); + Vector128 parameter0 = Sse.LoadVector128(pParameters + baseIndex0); + Vector128 parameter1 = Sse.LoadVector128(pParameters + baseIndex1); + Vector128 parameter2 = Sse.LoadVector128(pParameters + baseIndex2); + Vector128 parameter3 = Sse.LoadVector128(pParameters + baseIndex3); - Vector128 intInput0 = Sse41.ConvertToVector128Int32(pInput + inputIndex0); - Vector128 intInput1 = Sse41.ConvertToVector128Int32(pInput + inputIndex1); - Vector128 intInput2 = Sse41.ConvertToVector128Int32(pInput + inputIndex2); - Vector128 intInput3 = Sse41.ConvertToVector128Int32(pInput + inputIndex3); + Vector128 intInput0 = Sse41.ConvertToVector128Int32(pInput + inputIndex0); + Vector128 intInput1 = Sse41.ConvertToVector128Int32(pInput + inputIndex1); + Vector128 intInput2 = Sse41.ConvertToVector128Int32(pInput + inputIndex2); + Vector128 intInput3 = Sse41.ConvertToVector128Int32(pInput + inputIndex3); - Vector128 input0 = Sse2.ConvertToVector128Single(intInput0); - Vector128 input1 = Sse2.ConvertToVector128Single(intInput1); - Vector128 input2 = Sse2.ConvertToVector128Single(intInput2); - Vector128 input3 = Sse2.ConvertToVector128Single(intInput3); + Vector128 input0 = Sse2.ConvertToVector128Single(intInput0); + Vector128 input1 = Sse2.ConvertToVector128Single(intInput1); + Vector128 input2 = Sse2.ConvertToVector128Single(intInput2); + Vector128 input3 = Sse2.ConvertToVector128Single(intInput3); - Vector128 mix0 = Sse.Multiply(input0, parameter0); - Vector128 mix1 = Sse.Multiply(input1, parameter1); - Vector128 mix2 = Sse.Multiply(input2, parameter2); - Vector128 mix3 = Sse.Multiply(input3, parameter3); + Vector128 mix0 = Sse.Multiply(input0, parameter0); + Vector128 mix1 = Sse.Multiply(input1, parameter1); + Vector128 mix2 = Sse.Multiply(input2, parameter2); + Vector128 mix3 = Sse.Multiply(input3, parameter3); - Vector128 mix01 = Sse3.HorizontalAdd(mix0, mix1); - Vector128 mix23 = Sse3.HorizontalAdd(mix2, mix3); + Vector128 mix01 = Sse3.HorizontalAdd(mix0, mix1); + Vector128 mix23 = Sse3.HorizontalAdd(mix2, mix3); - Vector128 mix0123 = Sse3.HorizontalAdd(mix01, mix23); + Vector128 mix0123 = Sse3.HorizontalAdd(mix01, mix23); - Sse.Store(pOutput + (uint)i, Sse41.RoundToNearestInteger(mix0123)); - } + Sse.Store(pOutput + (uint)i, Sse41.RoundToNearestInteger(mix0123)); } } } @@ -526,34 +522,59 @@ namespace Ryujinx.Audio.Renderer.Dsp return _highCurveLut2F; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void ResampleHighQuality(Span outputBuffer, ReadOnlySpan inputBuffer, float ratio, ref float fraction, int sampleCount) + private static unsafe void ResampleHighQuality(Span outputBuffer, ReadOnlySpan inputBuffer, float ratio, ref float fraction, int sampleCount) { ReadOnlySpan parameters = GetHighParameter(ratio); int inputBufferIndex = 0; - // TODO: fast path - - for (int i = 0; i < sampleCount; i++) + if (Avx2.IsSupported) { - int baseIndex = (int)(fraction * 128) * 8; - ReadOnlySpan parameter = parameters.Slice(baseIndex, 8); - ReadOnlySpan currentInput = inputBuffer.Slice(inputBufferIndex, 8); + // Fast path; assumes 256-bit vectors for simplicity because the filter is 8 taps + fixed (short* pInput = inputBuffer) + fixed (float* pParameters = parameters) + { + for (int i = 0; i < sampleCount; i++) + { + int baseIndex = (int)(fraction * 128) * 8; - outputBuffer[i] = (float)Math.Round(currentInput[0] * parameter[0] + - currentInput[1] * parameter[1] + - currentInput[2] * parameter[2] + - currentInput[3] * parameter[3] + - currentInput[4] * parameter[4] + - currentInput[5] * parameter[5] + - currentInput[6] * parameter[6] + - currentInput[7] * parameter[7]); + Vector256 intInput = Avx2.ConvertToVector256Int32(pInput + inputBufferIndex); + Vector256 floatInput = Avx.ConvertToVector256Single(intInput); + Vector256 parameter = Avx.LoadVector256(pParameters + baseIndex); + Vector256 dp = Avx.DotProduct(floatInput, parameter, control: 0xFF); - fraction += ratio; - inputBufferIndex += (int)MathF.Truncate(fraction); + // avx2 does an 8-element dot product piecewise so we have to sum up 2 intermediate results + outputBuffer[i] = (float)Math.Round(dp[0] + dp[4]); - fraction -= (int)fraction; + fraction += ratio; + inputBufferIndex += (int)MathF.Truncate(fraction); + + fraction -= (int)fraction; + } + } + } + else + { + for (int i = 0; i < sampleCount; i++) + { + int baseIndex = (int)(fraction * 128) * 8; + ReadOnlySpan parameter = parameters.Slice(baseIndex, 8); + ReadOnlySpan currentInput = inputBuffer.Slice(inputBufferIndex, 8); + + outputBuffer[i] = (float)Math.Round(currentInput[0] * parameter[0] + + currentInput[1] * parameter[1] + + currentInput[2] * parameter[2] + + currentInput[3] * parameter[3] + + currentInput[4] * parameter[4] + + currentInput[5] * parameter[5] + + currentInput[6] * parameter[6] + + currentInput[7] * parameter[7]); + + fraction += ratio; + inputBufferIndex += (int)MathF.Truncate(fraction); + + fraction -= (int)fraction; + } } } diff --git a/Ryujinx.Audio/Renderer/Dsp/UpsamplerHelper.cs b/Ryujinx.Audio/Renderer/Dsp/UpsamplerHelper.cs index 847acec2..6cdab5a7 100644 --- a/Ryujinx.Audio/Renderer/Dsp/UpsamplerHelper.cs +++ b/Ryujinx.Audio/Renderer/Dsp/UpsamplerHelper.cs @@ -2,6 +2,7 @@ using Ryujinx.Audio.Renderer.Server.Upsampler; using Ryujinx.Common.Memory; using System; using System.Diagnostics; +using System.Numerics; using System.Runtime.CompilerServices; namespace Ryujinx.Audio.Renderer.Dsp @@ -70,16 +71,32 @@ namespace Ryujinx.Audio.Renderer.Dsp return; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] float DoFilterBank(ref UpsamplerBufferState state, in Array20 bank) { float result = 0.0f; Debug.Assert(state.History.Length == HistoryLength); Debug.Assert(bank.Length == FilterBankLength); - for (int j = 0; j < FilterBankLength; j++) + + int curIdx = 0; + if (Vector.IsHardwareAccelerated) { - result += bank[j] * state.History[j]; + // Do SIMD-accelerated block operations where possible. + // Only about a 2x speedup since filter bank length is short + int stopIdx = FilterBankLength - (FilterBankLength % Vector.Count); + while (curIdx < stopIdx) + { + result += Vector.Dot( + new Vector(bank.AsSpan().Slice(curIdx, Vector.Count)), + new Vector(state.History.AsSpan().Slice(curIdx, Vector.Count))); + curIdx += Vector.Count; + } + } + + while (curIdx < FilterBankLength) + { + result += bank[curIdx] * state.History[curIdx]; + curIdx++; } return result; diff --git a/Ryujinx.Tests/Audio/Renderer/Dsp/ResamplerTests.cs b/Ryujinx.Tests/Audio/Renderer/Dsp/ResamplerTests.cs new file mode 100644 index 00000000..364837ee --- /dev/null +++ b/Ryujinx.Tests/Audio/Renderer/Dsp/ResamplerTests.cs @@ -0,0 +1,93 @@ +using NUnit.Framework; +using Ryujinx.Audio.Renderer.Dsp; +using Ryujinx.Audio.Renderer.Parameter; +using Ryujinx.Audio.Renderer.Server.Upsampler; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; + +namespace Ryujinx.Tests.Audio.Renderer.Dsp +{ + class ResamplerTests + { + [Test] + [TestCase(VoiceInParameter.SampleRateConversionQuality.Low)] + [TestCase(VoiceInParameter.SampleRateConversionQuality.Default)] + [TestCase(VoiceInParameter.SampleRateConversionQuality.High)] + public void TestResamplerConsistencyUpsampling(VoiceInParameter.SampleRateConversionQuality quality) + { + DoResamplingTest(44100, 48000, quality); + } + + [Test] + [TestCase(VoiceInParameter.SampleRateConversionQuality.Low)] + [TestCase(VoiceInParameter.SampleRateConversionQuality.Default)] + [TestCase(VoiceInParameter.SampleRateConversionQuality.High)] + public void TestResamplerConsistencyDownsampling(VoiceInParameter.SampleRateConversionQuality quality) + { + DoResamplingTest(48000, 44100, quality); + } + + /// + /// Generates a 1-second sine wave sample at input rate, resamples it to output rate, and + /// ensures that it resampled at the expected rate with no discontinuities + /// + /// The input sample rate to test + /// The output sample rate to test + /// The resampler quality to use + private static void DoResamplingTest(int inputRate, int outputRate, VoiceInParameter.SampleRateConversionQuality quality) + { + float inputSampleRate = (float)inputRate; + float outputSampleRate = (float)outputRate; + int inputSampleCount = inputRate; + int outputSampleCount = outputRate; + short[] inputBuffer = new short[inputSampleCount + 100]; // add some safety buffer at the end + float[] outputBuffer = new float[outputSampleCount + 100]; + for (int sample = 0; sample < inputBuffer.Length; sample++) + { + // 440 hz sine wave with amplitude = 0.5f at input sample rate + inputBuffer[sample] = (short)(32767 * MathF.Sin((440 / inputSampleRate) * (float)sample * MathF.PI * 2f) * 0.5f); + } + + float fraction = 0; + + ResamplerHelper.Resample( + outputBuffer.AsSpan(), + inputBuffer.AsSpan(), + inputSampleRate / outputSampleRate, + ref fraction, + outputSampleCount, + quality, + false); + + float[] expectedOutput = new float[outputSampleCount]; + float sumDifference = 0; + int delay = quality switch + { + VoiceInParameter.SampleRateConversionQuality.High => 3, + VoiceInParameter.SampleRateConversionQuality.Default => 1, + _ => 0 + }; + + for (int sample = 0; sample < outputSampleCount; sample++) + { + outputBuffer[sample] /= 32767; + // 440 hz sine wave with amplitude = 0.5f at output sample rate + expectedOutput[sample] = MathF.Sin((440 / outputSampleRate) * (float)(sample + delay) * MathF.PI * 2f) * 0.5f; + float thisDelta = Math.Abs(expectedOutput[sample] - outputBuffer[sample]); + + // Ensure no discontinuities + Assert.IsTrue(thisDelta < 0.1f); + sumDifference += thisDelta; + } + + sumDifference = sumDifference / (float)outputSampleCount; + // Expect the output to be 99% similar to the expected resampled sine wave + Assert.IsTrue(sumDifference < 0.01f); + } + } +} diff --git a/Ryujinx.Tests/Audio/Renderer/Dsp/UpsamplerTests.cs b/Ryujinx.Tests/Audio/Renderer/Dsp/UpsamplerTests.cs new file mode 100644 index 00000000..2018752b --- /dev/null +++ b/Ryujinx.Tests/Audio/Renderer/Dsp/UpsamplerTests.cs @@ -0,0 +1,64 @@ +using NUnit.Framework; +using Ryujinx.Audio.Renderer.Dsp; +using Ryujinx.Audio.Renderer.Parameter; +using Ryujinx.Audio.Renderer.Server.Upsampler; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; + +namespace Ryujinx.Tests.Audio.Renderer.Dsp +{ + class UpsamplerTests + { + [Test] + public void TestUpsamplerConsistency() + { + UpsamplerBufferState bufferState = new UpsamplerBufferState(); + int inputBlockSize = 160; + int numInputSamples = 32000; + int numOutputSamples = 48000; + float inputSampleRate = numInputSamples; + float outputSampleRate = numOutputSamples; + float[] inputBuffer = new float[numInputSamples + 100]; + float[] outputBuffer = new float[numOutputSamples + 100]; + for (int sample = 0; sample < inputBuffer.Length; sample++) + { + // 440 hz sine wave with amplitude = 0.5f at input sample rate + inputBuffer[sample] = MathF.Sin((440 / inputSampleRate) * (float)sample * MathF.PI * 2f) * 0.5f; + } + + int inputIdx = 0; + int outputIdx = 0; + while (inputIdx + inputBlockSize < numInputSamples) + { + int outputBufLength = (int)Math.Round((float)(inputIdx + inputBlockSize) * outputSampleRate / inputSampleRate) - outputIdx; + UpsamplerHelper.Upsample( + outputBuffer.AsSpan(outputIdx), + inputBuffer.AsSpan(inputIdx), + outputBufLength, + inputBlockSize, + ref bufferState); + + inputIdx += inputBlockSize; + outputIdx += outputBufLength; + } + + float[] expectedOutput = new float[numOutputSamples]; + float sumDifference = 0; + for (int sample = 0; sample < numOutputSamples; sample++) + { + // 440 hz sine wave with amplitude = 0.5f at output sample rate with an offset of 15 + expectedOutput[sample] = MathF.Sin((440 / outputSampleRate) * (float)(sample - 15) * MathF.PI * 2f) * 0.5f; + sumDifference += Math.Abs(expectedOutput[sample] - outputBuffer[sample]); + } + + sumDifference = sumDifference / (float)expectedOutput.Length; + // Expect the output to be 98% similar to the expected resampled sine wave + Assert.IsTrue(sumDifference < 0.02f); + } + } +}