#version 450 core

#extension GL_EXT_shader_8bit_storage : require

// Each workgroup runs 64 invocations along x (1 along y and z). main() reads
// gl_WorkGroupSize.x and gl_LocalInvocationID.x to split the copies among them.
layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// Parameters for the stride conversion, packed into one ivec4. As read by main():
//   x = source stride, y = target stride, z = buffer size, w = source offset.
// All four are used as counts/indices into the uint8_t arrays in_data/out_data.
layout (std140, set = 0, binding = 0) uniform stride_arguments
{
    ivec4 stride_arguments_data;
};

// Source storage buffer, viewed as an unsized array of 8-bit values.
// main() only reads from it, beginning at the source offset (stride_arguments_data.w).
layout (std430, set = 1, binding = 1) buffer in_s
{
    uint8_t[] in_data;
};

// Destination storage buffer, viewed as an unsized array of 8-bit values.
// main() only writes to it: copied source values followed by zero padding,
// with each invocation's first write at (its first copy index * target stride).
layout (std430, set = 1, binding = 2) buffer out_s
{
    uint8_t[] out_data;
};

// Re-strides a buffer: copies (bufferSize / sourceStride) records of
// sourceStride values each from in_data (starting at sourceOffset) into
// out_data, appending (targetStride - sourceStride) zero values after each
// record when that difference is positive.
//
// The records are divided among the invocations of a single workgroup using
// gl_LocalInvocationID.x only; the workgroup ID is never consulted, so every
// dispatched workgroup performs the same copies.
//
// NOTE(review): looks like this expects targetStride >= sourceStride. When it
// is smaller, an invocation's first record lands at startCopy * targetStride
// but later records advance by sourceStride, so output ranges of different
// invocations overlap - confirm callers never request that.
void main()
{
    int sourceStride = stride_arguments_data.x;
    int targetStride = stride_arguments_data.y;
    int bufferSize = stride_arguments_data.z;
    int sourceOffset = stride_arguments_data.w;

    // Integer division by zero yields an undefined result in GLSL, so a zero
    // (or nonsensical negative) stride must not reach the division below.
    if (sourceStride <= 0) {
        return;
    }

    int copiesRequired = bufferSize / sourceStride;

    // Nothing to do for an empty buffer. A negative count is rejected too,
    // because '%' with a negative operand is undefined and would otherwise
    // feed an undefined value into the buffer indices.
    if (copiesRequired <= 0) {
        return;
    }

    int strideRemainder = targetStride - sourceStride;
    int invocations = int(gl_WorkGroupSize.x);
    int index = int(gl_LocalInvocationID.x);

    // Determine what slice of the stride copies this invocation will perform.

    // - Copies that all invocations perform.
    int allInvocationCopies = copiesRequired / invocations;

    // - Leftover copies, handed out one each to the lowest-numbered invocations.
    int leftoverCopies = copiesRequired % invocations;
    int extra = (index < leftoverCopies) ? 1 : 0;

    int copyCount = allInvocationCopies + extra;

    // Starting copy index: the even share of every lower-numbered invocation,
    // plus one for each of them that also received a leftover copy.
    int startCopy = allInvocationCopies * index + min(leftoverCopies, index);

    int srcOffset = sourceOffset + startCopy * sourceStride;
    int dstOffset = startCopy * targetStride;

    // Perform the copies for this region.
    for (int i = 0; i < copyCount; i++) {
        // Payload of the record.
        for (int j = 0; j < sourceStride; j++) {
            out_data[dstOffset++] = in_data[srcOffset++];
        }

        // Zero padding up to the target stride (skipped when not positive).
        for (int j = 0; j < strideRemainder; j++) {
            out_data[dstOffset++] = uint8_t(0);
        }
    }
}