ryujinx/Ryujinx.Graphics/Shader/Translation/Ssa.cs

330 lines
10 KiB
C#
Raw Normal View History

New shader translator implementation (#654) * Start implementing a new shader translator * Fix shift instructions and a typo * Small refactoring on StructuredProgram, move RemovePhis method to a separate class * Initial geometry shader support * Implement TLD4 * Fix -- There's no negation on FMUL32I * Add constant folding and algebraic simplification optimizations, nits * Some leftovers from constant folding * Avoid cast for constant assignments * Add a branch elimination pass, and misc small fixes * Remove redundant branches, add expression propagation and other improvements on the code * Small leftovers -- add missing break and continue, remove unused properties, other improvements * Add null check to handle empty block cases on block visitor * Add HADD2 and HMUL2 half float shader instructions * Optimize pack/unpack sequences, some fixes related to half float instructions * Add TXQ, TLD, TLDS and TLD4S shader texture instructions, and some support for bindless textures, some refactoring on codegen * Fix copy paste mistake that caused RZ to be ignored on the AST instruction * Add workaround for conditional exit, and fix half float instruction with constant buffer * Add missing 0.0 source for TLDS.LZ variants * Simplify the switch for TLDS.LZ * Texture instructions related fixes * Implement the HFMA instruction, and some misc. fixes * Enable constant folding on UnpackHalf2x16 instructions * Refactor HFMA to use OpCode* for opcode decoding rather than on the helper methods * Remove the old shader translator * Remove ShaderDeclInfo and other unused things * Add dual vertex shader support * Add ShaderConfig, used to pass shader type and maximum cbuffer size * Move and rename some instruction enums * Move texture instructions into a separate file * Move operand GetExpression and locals management to OperandManager * Optimize opcode decoding using a simple list and binary search * Add missing condition for do-while on goto elimination * Misc. fixes on texture instructions * Simplify TLDS switch * Address PR feedback, and a nit
2019-04-17 23:57:08 +00:00
using Ryujinx.Graphics.Shader.Decoders;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using System.Collections.Generic;
using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
namespace Ryujinx.Graphics.Shader.Translation
{
    //Builds SSA form for the shader IR: register operands are replaced by locals,
    //and phi nodes are materialized on demand at the blocks selected from the
    //iterated dominance frontiers of each register's definitions.
    static class Ssa
    {
        //Register keys are laid out as [GPRs][predicates][flags]; flag keys start here.
        private const int GprsAndPredsCount = RegisterConsts.GprsCount + RegisterConsts.PredsCount;

        //Per-block bookkeeping used while renaming.
        private class DefMap
        {
            //Value of each register at the end of the block: the block's last local
            //definition, or a materialized phi when the block has no local definition.
            private readonly Dictionary<Register, Operand> _map;

            //Phi nodes already materialized in this block, keyed by register.
            //Kept apart from _map because a block may hold both a phi (seen by uses
            //before the first local definition) and a later local definition.
            private readonly Dictionary<Register, Operand> _phis;

            //One bit per register key; set when the block needs a phi for that register.
            private readonly long[] _phiMasks;

            public DefMap()
            {
                _map  = new Dictionary<Register, Operand>();
                _phis = new Dictionary<Register, Operand>();

                _phiMasks = new long[(RegisterConsts.TotalCount + 63) / 64];
            }

            //Records the end-of-block value of a register; the first value recorded wins.
            public bool TryAddOperand(Register reg, Operand operand)
            {
                return _map.TryAdd(reg, operand);
            }

            public bool TryGetOperand(Register reg, out Operand operand)
            {
                return _map.TryGetValue(reg, out operand);
            }

            //Records the destination local of the phi materialized for the register.
            public void SetPhiOperand(Register reg, Operand operand)
            {
                _phis[reg] = operand;
            }

            public bool TryGetPhiOperand(Register reg, out Operand operand)
            {
                return _phis.TryGetValue(reg, out operand);
            }

            //Marks the register as needing a phi in this block.
            //Returns false if it was already marked.
            public bool AddPhi(Register reg)
            {
                int key = GetKeyFromRegister(reg);

                int index = key / 64;
                int bit   = key & 63;

                long mask = 1L << bit;

                if ((_phiMasks[index] & mask) != 0)
                {
                    return false;
                }

                _phiMasks[index] |= mask;

                return true;
            }

            //Returns true if the register was marked as needing a phi in this block.
            public bool HasPhi(Register reg)
            {
                int key = GetKeyFromRegister(reg);

                int index = key / 64;
                int bit   = key & 63;

                return (_phiMasks[index] & (1L << bit)) != 0;
            }
        }

        //A register value together with the block in which it was found.
        private struct Definition
        {
            public BasicBlock Block { get; }
            public Operand    Local { get; }

            public Definition(BasicBlock block, Operand local)
            {
                Block = block;
                Local = local;
            }
        }

        /// <summary>
        /// Rewrites the register operands of all the given blocks into SSA locals,
        /// inserting phi nodes where a use can be reached by more than one definition.
        /// </summary>
        /// <remarks>
        /// Reads each block's Index, Operations, Predecessors, ImmediateDominator and
        /// DominanceFrontiers, so those are expected to be computed beforehand.
        /// Index is used to index an array of blocks.Length entries, so it presumably
        /// matches the block's position in <paramref name="blocks"/> (TODO confirm).
        /// </remarks>
        /// <param name="blocks">All blocks of the program; their operations are modified in place</param>
        public static void Rename(BasicBlock[] blocks)
        {
            DefMap[] globalDefs = new DefMap[blocks.Length];

            for (int blkIndex = 0; blkIndex < blocks.Length; blkIndex++)
            {
                globalDefs[blkIndex] = new DefMap();
            }

            //Work list for the iterated dominance frontier; always drained before reuse.
            Queue<BasicBlock> dfPhiBlocks = new Queue<BasicBlock>();

            //First pass, get all defs and locals uses.
            for (int blkIndex = 0; blkIndex < blocks.Length; blkIndex++)
            {
                RenameLocals(globalDefs, blocks[blkIndex], dfPhiBlocks);
            }

            //Second pass, rename variables with definitions on different blocks.
            //This must run after the first pass has finished for every block, since it
            //reads the end-of-block values and phi marks recorded there.
            for (int blkIndex = 0; blkIndex < blocks.Length; blkIndex++)
            {
                RenameGlobals(globalDefs, blocks[blkIndex]);
            }
        }

        //First pass over one block. It does three things:
        //- gives every register definition a fresh local and redirects later uses
        //  inside the same block to it;
        //- records the block's last local of each register as its end-of-block value;
        //- marks a pending phi on the iterated dominance frontier for every register
        //  the block defines.
        private static void RenameLocals(DefMap[] globalDefs, BasicBlock block, Queue<BasicBlock> dfPhiBlocks)
        {
            Operand[] localDefs = new Operand[RegisterConsts.TotalCount];

            Operand RenameLocal(Operand operand)
            {
                if (operand != null && operand.Type == OperandType.Register)
                {
                    Operand local = localDefs[GetKeyFromRegister(operand.GetRegister())];

                    operand = local ?? operand;
                }

                return operand;
            }

            LinkedListNode<INode> node = block.Operations.First;

            while (node != null)
            {
                if (node.Value is Operation operation)
                {
                    //Sources are renamed before the destination, so an instruction that
                    //reads and writes the same register reads the previous value.
                    for (int index = 0; index < operation.SourcesCount; index++)
                    {
                        operation.SetSource(index, RenameLocal(operation.GetSource(index)));
                    }

                    if (operation.Dest != null && operation.Dest.Type == OperandType.Register)
                    {
                        Operand local = Local();

                        localDefs[GetKeyFromRegister(operation.Dest.GetRegister())] = local;

                        operation.Dest = local;
                    }
                }

                node = node.Next;
            }

            for (int key = 0; key < RegisterConsts.TotalCount; key++)
            {
                Operand local = localDefs[key];

                if (local == null)
                {
                    continue;
                }

                Register reg = GetRegisterFromKey(key);

                globalDefs[block.Index].TryAddOperand(reg, local);

                //Iterated dominance frontier: a block that newly needs a phi also defines
                //the register, so its own frontier must be visited as well.
                dfPhiBlocks.Enqueue(block);

                while (dfPhiBlocks.TryDequeue(out BasicBlock dfPhiBlock))
                {
                    foreach (BasicBlock domFrontier in dfPhiBlock.DominanceFrontiers)
                    {
                        if (globalDefs[domFrontier.Index].AddPhi(reg))
                        {
                            dfPhiBlocks.Enqueue(domFrontier);
                        }
                    }
                }
            }
        }

        //Second pass over one block. Any register operand still left after the first
        //pass is a use that is not preceded by a definition in this block. Each one is
        //resolved to the value reaching the block entry, and the result is cached per
        //register so every such use in the block shares it.
        private static void RenameGlobals(DefMap[] globalDefs, BasicBlock block)
        {
            Operand[] localDefs = new Operand[RegisterConsts.TotalCount];

            Operand RenameGlobal(Operand operand)
            {
                if (operand != null && operand.Type == OperandType.Register)
                {
                    int key = GetKeyFromRegister(operand.GetRegister());

                    Operand local = localDefs[key];

                    if (local != null)
                    {
                        return local;
                    }

                    operand = FindDefinitionForCurr(globalDefs, block, operand.GetRegister());

                    localDefs[key] = operand;
                }

                return operand;
            }

            LinkedListNode<INode> node = block.Operations.First;

            while (node != null)
            {
                if (node.Value is Operation operation)
                {
                    for (int index = 0; index < operation.SourcesCount; index++)
                    {
                        operation.SetSource(index, RenameGlobal(operation.GetSource(index)));
                    }
                }

                node = node.Next;
            }
        }

        //Finds the value of a register at the entry of the current block. It is one of:
        //- this block's phi, if one is needed here;
        //- otherwise the definition found by walking up from the immediate dominator;
        //- otherwise Undef, when the block is its own immediate dominator.
        private static Operand FindDefinitionForCurr(DefMap[] globalDefs, BasicBlock current, Register reg)
        {
            if (globalDefs[current.Index].HasPhi(reg))
            {
                return InsertPhi(globalDefs, current, reg);
            }

            if (current != current.ImmediateDominator)
            {
                return FindDefinition(globalDefs, current.ImmediateDominator, reg).Local;
            }

            return Undef();
        }

        //Finds the value of a register at the end of the given block. It walks the block
        //and its dominators and stops at the first one that either holds an end-of-block
        //value or needs a phi; a needed phi is materialized at that point. Returns Undef
        //(paired with the starting block) if no dominator defines the register.
        private static Definition FindDefinition(DefMap[] globalDefs, BasicBlock current, Register reg)
        {
            foreach (BasicBlock block in SelfAndImmediateDominators(current))
            {
                DefMap defMap = globalDefs[block.Index];

                if (defMap.TryGetOperand(reg, out Operand lastDef))
                {
                    return new Definition(block, lastDef);
                }

                if (defMap.HasPhi(reg))
                {
                    return new Definition(block, InsertPhi(globalDefs, block, reg));
                }
            }

            return new Definition(current, Undef());
        }

        //Yields the block, then each successive immediate dominator. The walk ends with
        //the first block that is its own immediate dominator.
        private static IEnumerable<BasicBlock> SelfAndImmediateDominators(BasicBlock block)
        {
            while (block != block.ImmediateDominator)
            {
                yield return block;

                block = block.ImmediateDominator;
            }

            yield return block;
        }

        //Returns the local defined by this block's phi for the register.
        //Materializes the phi the first time it is requested.
        private static Operand InsertPhi(DefMap[] globalDefs, BasicBlock block, Register reg)
        {
            DefMap defMap = globalDefs[block.Index];

            //The phi may already exist: a lookup coming from another block (for example
            //through a loop back edge) can materialize it before this block's own uses
            //are renamed. Reuse it rather than emitting a duplicate phi with identical sources.
            if (defMap.TryGetPhiOperand(reg, out Operand existing))
            {
                return existing;
            }

            //This block has a Phi that has not been materialized yet, but that
            //would define a new version of the variable we're looking for. We need
            //to materialize the Phi, add all the block/operand pairs into the Phi, and
            //then use the definition from that Phi.
            Operand local = Local();

            PhiNode phi = new PhiNode(local);

            AddPhi(block, phi);

            //Record the phi before visiting the predecessors. Lookups that come back to this
            //block through a cycle then stop at its recorded value instead of recursing forever.
            //TryAddOperand does nothing if the block has its own local definition, which
            //remains the block's end-of-block value.
            defMap.SetPhiOperand(reg, local);
            defMap.TryAddOperand(reg, local);

            foreach (BasicBlock predecessor in block.Predecessors)
            {
                Definition def = FindDefinition(globalDefs, predecessor, reg);

                //Each source is paired with the block where its definition was found,
                //which may be a dominator of the predecessor rather than the predecessor itself.
                phi.AddSource(def.Block, def.Local);
            }

            return local;
        }

        //Inserts the phi after any phi nodes already at the start of the block, so that
        //all phis stay grouped at the top of the block.
        private static void AddPhi(BasicBlock block, PhiNode phi)
        {
            LinkedListNode<INode> node = block.Operations.First;

            if (node != null)
            {
                while (node.Next?.Value is PhiNode)
                {
                    node = node.Next;
                }
            }

            if (node?.Value is PhiNode)
            {
                block.Operations.AddAfter(node, phi);
            }
            else
            {
                block.Operations.AddFirst(phi);
            }
        }

        //Maps a register to a dense key in [0, RegisterConsts.TotalCount).
        //Any type other than Gpr or Predicate is treated as a flag register.
        private static int GetKeyFromRegister(Register reg)
        {
            if (reg.Type == RegisterType.Gpr)
            {
                return reg.Index;
            }
            else if (reg.Type == RegisterType.Predicate)
            {
                return RegisterConsts.GprsCount + reg.Index;
            }
            else /* if (reg.Type == RegisterType.Flag) */
            {
                return GprsAndPredsCount + reg.Index;
            }
        }

        //Inverse of GetKeyFromRegister.
        private static Register GetRegisterFromKey(int key)
        {
            if (key < RegisterConsts.GprsCount)
            {
                return new Register(key, RegisterType.Gpr);
            }
            else if (key < GprsAndPredsCount)
            {
                return new Register(key - RegisterConsts.GprsCount, RegisterType.Predicate);
            }
            else /* if (key < RegisterConsts.TotalCount) */
            {
                return new Register(key - GprsAndPredsCount, RegisterType.Flag);
            }
        }
    }
}