/*******************************************************************************
 * Copyright (c) 2017 Nerian Vision Technologies
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *******************************************************************************/

#include <cstring>
#include <iostream>
#include <limits>
#include <vector>
#include <memory>
#include <algorithm>
#include "visiontransfer/imageprotocol.h"
#include "visiontransfer/alignedallocator.h"
#include "visiontransfer/datablockprotocol.h"
#include "visiontransfer/exceptions.h"

// Network headers
#ifdef _WIN32
    #ifndef NOMINMAX
        #define NOMINMAX
    #endif
    #include <winsock2.h>
#else
    #include <arpa/inet.h>
#endif

// SIMD Headers
#ifdef __AVX2__
#include <immintrin.h>
#elif __SSE2__
#include <emmintrin.h>
#endif

using namespace std;

/*************** Pimpl class containing all private members ***********/

// Private implementation of ImageProtocol. It owns a DataBlockProtocol that
// moves raw bytes, and adds the image header handling and the decoding of
// received pixel data on top of it.
class ImageProtocol::Pimpl {
public:
    Pimpl(ProtocolType protType);

    // Redeclaration of public members
    void setTransferImagePair(const ImagePair& imagePair);
    void setRawTransferData(const ImagePair& metaData, unsigned char* rawData,
        int firstTileWidth = 0, int secondTileWidth = 0, int validBytes = 0x7FFFFFFF);
    void setRawValidBytes(int validBytes);
    const unsigned char* getTransferMessage(int& length);
    bool transferComplete();
    void resetTransfer();
    bool getReceivedImagePair(ImagePair& imagePair);
    bool getPartiallyReceivedImagePair(ImagePair& imagePair,
        int& validRows, bool& complete);
    bool imagesReceived() const;

    unsigned char* getNextReceiveBuffer(int& maxLength);

    bool processReceivedMessage(int length);
    int getProspectiveMessageSize();
    void resetReception();

private:
    // Header data transferred in the first packet. The struct is byte-packed
    // because it is copied to and from the wire verbatim. The 16-bit and
    // 32-bit integer fields are converted to network byte order for the
    // transfer; the q matrix is copied as-is.
#pragma pack(push,1)
    struct HeaderData{
        unsigned char protocolVersion;
        unsigned char padding0;

        // Image size in pixels
        unsigned short width;
        unsigned short height;

        // Widths of the two tiles of a tiled transfer. A secondTileWidth of
        // zero selects the non-tiled (row-interleaved) layout on reception.
        unsigned short firstTileWidth;
        unsigned short secondTileWidth;

        // ImagePair::ImageFormat of the first and second image
        unsigned char format0;
        unsigned char format1;
        unsigned char minDisparity;
        unsigned char maxDisparity;

        unsigned int seqNum;
        int timeSec;
        int timeMicrosec;

        // 4x4 Q matrix, stored as 16 consecutive floats
        float q[16];

        // Pads the header to a multiple of 32 bytes (96 bytes in total with
        // the usual 1/2/4-byte sizes of char/short/int/float).
        unsigned char padding1[6]; // Pad to 32 bytes
    };
#pragma pack(pop)

    // Protocol version written to outgoing headers and the highest version
    // accepted on reception.
    static const unsigned char CURRENT_VERSION = 0x04;

    // Underlying protocol for data transfers
    DataBlockProtocol dataProt;
    ProtocolType protType;

    // Transfer related variables
    bool headerTransferred;                 // true once the payload has been handed to dataProt
    std::vector<unsigned char> headerBuffer; // serialized HeaderData of the current transfer
    std::vector<unsigned char> rawBuffer;    // interleaved copy made by setTransferImagePair()
    unsigned char* rawData;                  // payload to send (rawBuffer or caller-owned memory)
    int rawValidBytes;                       // number of payload bytes that may be sent so far
    int rawDataLength;                       // total payload size in bytes

    // Reception related variables
    std::vector<unsigned char, AlignedAllocator<unsigned char> >decodeBuffer[2];
    bool receiveHeaderParsed;                // true once receiveHeader holds a decoded header
    HeaderData receiveHeader;                // header of the frame being received (host byte order)
    int lastReceivedPayloadBytes[2];         // payload bytes already decoded, per image
    int receiveTotalSize;                    // expected frame size including the header
    bool receptionDone;                      // true once a complete frame has been received

    // Copies the transmission header to the given buffer
    void copyHeaderToBuffer(const ImagePair& imagePair, int firstTileWidth,
        int secondTileWidth, unsigned char* buffer);

    // Decodes header information from the received data
    void tryDecodeHeader(const unsigned char* receivedData, int receivedBytes);

    // Decodes a received image from an interleaved buffer
    unsigned char* decodeInterleaved(int imageNumber, int receivedBytes,
        unsigned char* data, int& validRows, int& rowStride);

    // Decodes the 12-bit disparity map into 16-bit values
    void decodeSubpixel(int startRow, int stopRow, unsigned const char* src,
        unsigned char* dst, int srcStride, int dstStride, int rowWidth);

    // Various implementations of decodeSubpixel()
    template <bool alignedLoad>
    void decodeSubpixelSSE2(int startRow, int stopRow, const unsigned char* dispStart,
        const unsigned char* subpixStart,  int width, unsigned short* dst, int srcStride, int dstStride);

    template <bool alignedLoad>
    void decodeSubpixelAVX2(int startRow, int stopRow, const unsigned char* dispStart,
        const unsigned char* subpixStart,  int width, unsigned short* dst, int srcStride, int dstStride);

    void decodeSubpixelFallback(int startRow, int stopRow, const unsigned char* dispStart,
        const unsigned char* subpixStart, int width, unsigned short* dst, int srcStride, int dstStride);

    // Returns the size in bytes of a frame with the given geometry and
    // formats, plus headerSize
    int getFrameSize(int width, int height, int firstTileWidth, int secondTileWidth,
        ImagePair::ImageFormat format0, ImagePair::ImageFormat format1, int headerSize);

    // Returns the number of 4-bit nibbles one pixel of the given format
    // occupies on the wire
    int getFormatNibbles(ImagePair::ImageFormat format);

    // Decodes the newly received rows of a tiled transfer into decodeBuffer
    void decodeTiledImage(int imageNumber, int lastReceivedPayloadBytes, int receivedPayloadBytes,
        const unsigned char* data, int firstTileStride, int secondTileStride, int& validRows,
        ImagePair::ImageFormat format);

    // Copies rows [startRow, stopRow) of an 8-bit tile into the destination image
    void decodeRowsFromTile(int startRow, int stopRow, unsigned const char* src,
        unsigned char* dst, int srcStride, int dstStride, int tileWidth);

    // Resizes decodeBuffer[imageNumber] to hold one full decoded image
    void allocateDecodeBuffer(int imageNumber);
};


/******************** Stubs for all public members ********************/

// Creates the protocol object; all state lives in the heap-allocated Pimpl
// instance, which is released again in the destructor.
ImageProtocol::ImageProtocol(ProtocolType protType)
    : pimpl(new Pimpl(protType)) {
    // All initializations are done by the Pimpl class
}

// Releases the Pimpl instance allocated in the constructor.
ImageProtocol::~ImageProtocol() {
    delete pimpl;
}

// Forwards to Pimpl::setTransferImagePair().
void ImageProtocol::setTransferImagePair(const ImagePair& imagePair) {
    pimpl->setTransferImagePair(imagePair);
}

// Forwards to Pimpl::setRawTransferData(); all arguments are passed through
// unchanged.
void ImageProtocol::setRawTransferData(const ImagePair& metaData,
        unsigned char* imageData, int firstTileWidth, int secondTileWidth, int validBytes) {
    pimpl->setRawTransferData(metaData, imageData, firstTileWidth, secondTileWidth, validBytes);
}

// Forwards to Pimpl::setRawValidBytes().
void ImageProtocol::setRawValidBytes(int validBytes) {
    pimpl->setRawValidBytes(validBytes);
}

// Forwards to Pimpl::getTransferMessage(); length is an output parameter
// filled in by the implementation.
const unsigned char* ImageProtocol::getTransferMessage(int& length) {
    return pimpl->getTransferMessage(length);
}

// Forwards to Pimpl::transferComplete().
bool ImageProtocol::transferComplete() {
    return pimpl->transferComplete();
}

// Forwards to Pimpl::resetTransfer().
void ImageProtocol::resetTransfer() {
    pimpl->resetTransfer();
}

// Forwards to Pimpl::getReceivedImagePair().
bool ImageProtocol::getReceivedImagePair(ImagePair& imagePair) {
    return pimpl->getReceivedImagePair(imagePair);
}

// Forwards to Pimpl::getPartiallyReceivedImagePair(); validRows and complete
// are output parameters.
bool ImageProtocol::getPartiallyReceivedImagePair(
        ImagePair& imagePair, int& validRows, bool& complete) {
    return pimpl->getPartiallyReceivedImagePair(imagePair, validRows, complete);
}

// Forwards to Pimpl::imagesReceived().
bool ImageProtocol::imagesReceived() const {
    return pimpl->imagesReceived();
}

// Forwards to Pimpl::getNextReceiveBuffer(); maxLength is an output
// parameter.
unsigned char* ImageProtocol::getNextReceiveBuffer(int& maxLength) {
    return pimpl->getNextReceiveBuffer(maxLength);
}

// Forwards to Pimpl::processReceivedMessage().
bool ImageProtocol::processReceivedMessage(int length) {
    return pimpl->processReceivedMessage(length);
}

// Forwards to Pimpl::resetReception().
void ImageProtocol::resetReception() {
    pimpl->resetReception();
}

/******************** Implementation in pimpl class *******************/

// Initializes all transfer and reception state for the given protocol type.
// The underlying DataBlockProtocol is created with the same protocol type.
ImageProtocol::Pimpl::Pimpl(ProtocolType protType)
        :dataProt(static_cast<DataBlockProtocol::ProtocolType>(protType)),
        protType(protType), headerTransferred(false),
        rawData(nullptr), rawValidBytes(0), rawDataLength(0), receiveHeaderParsed(false),
        lastReceivedPayloadBytes{0, 0}, receiveTotalSize(0), receptionDone(false) {
    // Zero-filled buffer for the serialized header plus two extra bytes.
    // assign() clears the whole buffer; the previous memset only cleared
    // sizeof(size_t) bytes because it applied sizeof to the size() result.
    headerBuffer.assign(sizeof(HeaderData) + sizeof(unsigned short), 0);
    memset(&receiveHeader, 0, sizeof(receiveHeader));

    // Just after start-up we don't yet know the expected data size. So let's
    // just allocate enough memory for one UDP packet
    dataProt.setReceiveDataSize(DataBlockProtocol::MAX_UDP_BYTES_TRANSFER);
}

// Prepares the transfer of a complete image pair.
//
// The header is serialized and handed to the data block protocol first. The
// pixel data of both images is then copied row-interleaved (row y of image 0,
// row y of image 1, ...) into rawBuffer in the wire layout that
// decodeInterleaved() / decodeSubpixel() expect on the receiving side:
//   - 8-bit images:     width bytes per row
//   - all other images: width bytes holding bits 4..11 of each 16-bit pixel,
//                       followed by width/2 bytes holding the low 4 bits
//                       (even pixel in the low nibble, odd pixel in the high
//                       nibble)
// Previously non-8-bit images were copied with two bytes per pixel, which did
// not match the frame size announced via getFrameSize() nor the decoder.
//
// Throws ProtocolException if a pixel data pointer is null, or if a non-8-bit
// image has an odd width (the nibble packing needs pixel pairs).
void ImageProtocol::Pimpl::setTransferImagePair(const ImagePair& imagePair) {
    if(imagePair.getPixelData(0) == nullptr || imagePair.getPixelData(1) == nullptr) {
        throw ProtocolException("Image data is null pointer!");
    }

    const int width = imagePair.getWidth();
    const int height = imagePair.getHeight();

    // Same format test as getFrameSize(): everything but 8-bit is packed
    // with three nibbles per pixel.
    const bool packed[2] = {
        imagePair.getPixelFormat(0) != ImagePair::FORMAT_8_BIT,
        imagePair.getPixelFormat(1) != ImagePair::FORMAT_8_BIT
    };

    if((packed[0] || packed[1]) && width % 2 != 0) {
        throw ProtocolException("Image width must be even for 12-bit images!");
    }

    // Bytes one image row occupies on the wire
    const int wireRowSize[2] = {
        packed[0] ? width*3/2 : width,
        packed[1] ? width*3/2 : width
    };

    headerTransferred = false;

    // Set header as first piece of data
    copyHeaderToBuffer(imagePair, 0, 0, &headerBuffer[0]);
    dataProt.startTransfer();
    dataProt.setTransferData(&headerBuffer[0], sizeof(HeaderData));

    // Make an interleaved copy of both images. The two extra bytes at the
    // end are kept from the original implementation and are not transferred.
    rawBuffer.resize(static_cast<size_t>(height)*(wireRowSize[0] + wireRowSize[1]) + sizeof(short));

    unsigned char* out = &rawBuffer[0];
    for(int y = 0; y < height; y++) {
        for(int i = 0; i < 2; i++) {
            const unsigned char* srcRow = &imagePair.getPixelData(i)[y*imagePair.getRowStride(i)];

            if(packed[i]) {
                unsigned char* dispOut = out;
                unsigned char* subpixOut = out + width;

                for(int x = 0; x < width; x += 2) {
                    // memcpy avoids alignment assumptions on the source rows
                    unsigned short evenPixel = 0, oddPixel = 0;
                    memcpy(&evenPixel, &srcRow[2*x], sizeof(evenPixel));
                    memcpy(&oddPixel, &srcRow[2*(x + 1)], sizeof(oddPixel));

                    dispOut[x] = static_cast<unsigned char>((evenPixel >> 4) & 0xFF);
                    dispOut[x + 1] = static_cast<unsigned char>((oddPixel >> 4) & 0xFF);
                    subpixOut[x/2] = static_cast<unsigned char>(
                        (evenPixel & 0x0F) | ((oddPixel & 0x0F) << 4));
                }
            } else {
                memcpy(out, srcRow, width);
            }

            out += wireRowSize[i];
        }
    }

    rawData = &rawBuffer[0];
    rawValidBytes = static_cast<int>(rawBuffer.size() - sizeof(short));

    rawDataLength = getFrameSize(width, height, 0, 0,
        imagePair.getPixelFormat(0), imagePair.getPixelFormat(1), 0);
}

// Prepares the transfer of caller-provided, already encoded pixel data.
//
// metaData supplies the header fields, rawData points to the payload (it is
// not copied, only the pointer is stored) and validBytes is the number of
// payload bytes that may be sent right away.
//
// Throws ProtocolException if rawData is null or the first image is not
// 8-bit.
void ImageProtocol::Pimpl::setRawTransferData(const ImagePair& metaData, unsigned char* rawData,
        int firstTileWidth, int secondTileWidth, int validBytes) {
    if(rawData == nullptr) {
        throw ProtocolException("Image data is null pointer!");
    }
    if(metaData.getPixelFormat(0) != ImagePair::FORMAT_8_BIT) {
        throw ProtocolException("First image must have 8-bit depth!");
    }

    // The header always goes out before the payload
    unsigned char* const headerStart = &headerBuffer[0];
    headerTransferred = false;
    copyHeaderToBuffer(metaData, firstTileWidth, secondTileWidth, headerStart);
    dataProt.startTransfer();
    dataProt.setTransferData(headerStart, sizeof(HeaderData));

    // Remember the payload; it is handed to dataProt by getTransferMessage()
    // once the header has been sent.
    this->rawData = rawData;
    rawValidBytes = validBytes;
    rawDataLength = getFrameSize(
        metaData.getWidth(), metaData.getHeight(),
        firstTileWidth, secondTileWidth,
        metaData.getPixelFormat(0), metaData.getPixelFormat(1),
        0);
}

// Updates the number of payload bytes that may be transferred. The value is
// only passed on to the data block protocol once the payload transfer has
// started; before that it is picked up by getTransferMessage().
void ImageProtocol::Pimpl::setRawValidBytes(int validBytes) {
    rawValidBytes = validBytes;

    if(!headerTransferred) {
        return;
    }
    dataProt.setTransferValidBytes(validBytes);
}

// Returns the next message to send and stores its size in length.
//
// When the data block protocol has no further message, the header has been
// sent completely. If payload bytes are available at that point, the payload
// is handed to the data block protocol and its first message is returned
// instead.
const unsigned char* ImageProtocol::Pimpl::getTransferMessage(int& length) {
    const unsigned char* pending = dataProt.getTransferMessage(length);

    const bool switchToPayload =
        pending == nullptr && !headerTransferred && rawValidBytes > 0;
    if(!switchToPayload) {
        return pending;
    }

    // Transmitting the header is complete. Continue with the actual payload.
    headerTransferred = true;
    dataProt.setTransferData(rawData, rawDataLength, rawValidBytes);
    return dataProt.getTransferMessage(length);
}

// The transfer is complete once the payload stage has been entered and the
// data block protocol reports its current transfer as complete.
bool ImageProtocol::Pimpl::transferComplete() {
    const bool blockTransferDone = dataProt.transferComplete();
    return blockTransferDone && headerTransferred;
}

int ImageProtocol::Pimpl::getFrameSize(int width, int height, int firstTileWidth,
        int secondTileWidth, ImagePair::ImageFormat format0,
        ImagePair::ImageFormat format1, int headerSize) {
    int nibbles0 = format0 == ImagePair::FORMAT_8_BIT ? 2 : 3;
    int nibbles1 = format1 == ImagePair::FORMAT_8_BIT ? 2 : 3;

    int effectiveWidth = firstTileWidth > 0 ? firstTileWidth + secondTileWidth : width;

    return (effectiveWidth * height * (nibbles0 + nibbles1)) /2 + headerSize;
}

// Returns how many nibbles (4-bit units) one pixel of the given format
// occupies: three for the 12-bit format, two for everything else.
int ImageProtocol::Pimpl::getFormatNibbles(ImagePair::ImageFormat format) {
    return (format == ImagePair::FORMAT_12_BIT) ? 3 : 2;
}

void ImageProtocol::Pimpl::copyHeaderToBuffer(const ImagePair& imagePair,
        int firstTileWidth, int secondTileWidth, unsigned char* buffer) {
    HeaderData* transferHeader = reinterpret_cast<HeaderData*>(buffer);
    memset(transferHeader, 0, sizeof(*transferHeader));
    transferHeader->protocolVersion = CURRENT_VERSION;
    transferHeader->width = htons(imagePair.getWidth());
    transferHeader->height = htons(imagePair.getHeight());
    transferHeader->firstTileWidth = htons(firstTileWidth);
    transferHeader->secondTileWidth = htons(secondTileWidth);
    transferHeader->format0 = static_cast<unsigned char>(imagePair.getPixelFormat(0));
    transferHeader->format1 = static_cast<unsigned char>(imagePair.getPixelFormat(1));
    transferHeader->seqNum = static_cast<unsigned int>(htonl(imagePair.getSequenceNumber()));

    int minDisp = 0, maxDisp = 0;
    imagePair.getDisparityRange(minDisp, maxDisp);
    transferHeader->minDisparity = minDisp;
    transferHeader->maxDisparity = maxDisp;

    int timeSec = 0, timeMicrosec = 0;
    imagePair.getTimestamp(timeSec, timeMicrosec);
    transferHeader->timeSec = static_cast<int>(htonl(static_cast<unsigned int>(timeSec)));
    transferHeader->timeMicrosec = static_cast<int>(htonl(static_cast<unsigned int>(timeMicrosec)));

    if(imagePair.getQMatrix() != nullptr) {
        memcpy(transferHeader->q, imagePair.getQMatrix(), sizeof(float)*16);
    }
}

// Resets the transfer state of the underlying data block protocol. The
// members of this class (e.g. headerTransferred) are left untouched; they are
// re-initialized by the next setTransferImagePair() / setRawTransferData().
void ImageProtocol::Pimpl::resetTransfer() {
    dataProt.resetTransfer();
}

// Returns the buffer the next received message should be written to.
// maxLength is set to the maximum payload size plus the protocol overhead of
// the underlying data block protocol, and is also passed on to it.
unsigned char* ImageProtocol::Pimpl::getNextReceiveBuffer(int& maxLength) {
    maxLength = dataProt.getMaxPayloadSize() + dataProt.getProtocolOverhead();
    return dataProt.getNextReceiveBuffer(maxLength);
}

// Processes a message of the given length that has been written to the
// buffer obtained from getNextReceiveBuffer().
//
// Returns false if the message was rejected by the data block protocol or if
// more bytes than the announced frame size have been received. In both cases
// the complete reception state is reset. Returns true otherwise.
// tryDecodeHeader() may throw ProtocolException.
bool ImageProtocol::Pimpl::processReceivedMessage(int length) {
    receptionDone = false;

    // Add the received message
    if(!dataProt.processReceivedMessage(length)) {
        resetReception();
        return false;
    }

    int receivedBytes = 0;
    const unsigned char* receivedData = dataProt.getReceivedData(receivedBytes);

    // Immediately try to decode the header
    if(!receiveHeaderParsed && receivedBytes > 0) {
        tryDecodeHeader(receivedData, receivedBytes);
    }

    // Check if we have received a complete frame
    // NOTE(review): while no header has been parsed receiveTotalSize is 0, so
    // a partially received header takes the "corrupted" branch below -
    // confirm that the data block protocol always delivers the header whole.
    if(receivedBytes == receiveTotalSize) {
        receptionDone = true;
        dataProt.finishReception();
    } else if(receivedBytes > receiveTotalSize) {
        // This is a corrupted frame. Reset the header state as well, not only
        // the data block protocol: otherwise the stale header, frame size and
        // row counters would be applied to the next frame.
        resetReception();
        return false;
    }

    return true;
}

void ImageProtocol::Pimpl::tryDecodeHeader(const
unsigned char* receivedData, int receivedBytes) {
    if(receivedBytes >= static_cast<int>(sizeof(HeaderData))) {
        receiveHeader =  *reinterpret_cast<const HeaderData*>(receivedData);

        if(receiveHeader.protocolVersion > CURRENT_VERSION ||
                receiveHeader.protocolVersion < 4) {
            throw ProtocolException("Protocol version mismatch!");
        }

        // Convert byte order
        receiveHeader.width = ntohs(receiveHeader.width);
        receiveHeader.height = ntohs(receiveHeader.height);
        receiveHeader.firstTileWidth = ntohs(receiveHeader.firstTileWidth);
        receiveHeader.secondTileWidth = ntohs(receiveHeader.secondTileWidth);
        receiveHeader.timeSec = static_cast<int>(
            htonl(static_cast<unsigned int>(receiveHeader.timeSec)));
        receiveHeader.timeMicrosec = static_cast<int>(
            htonl(static_cast<unsigned int>(receiveHeader.timeMicrosec)));
        receiveHeader.seqNum = htonl(receiveHeader.seqNum);

        // Make sure that the receive buffer is large enough
        receiveTotalSize = getFrameSize(
            receiveHeader.width,
            receiveHeader.height,
            receiveHeader.firstTileWidth,
            receiveHeader.secondTileWidth,
            static_cast<ImagePair::ImageFormat>(receiveHeader.format0),
            static_cast<ImagePair::ImageFormat>(receiveHeader.format1),
            sizeof(HeaderData));

        dataProt.setReceiveDataSize(receiveTotalSize);
        receiveHeaderParsed = true;
    }
}

// Returns true if a complete frame has been received and its header has been
// parsed. The header check matters because receptionDone can also be set
// while receiveTotalSize is still 0, i.e. before any header was decoded.
bool ImageProtocol::Pimpl::imagesReceived() const {
    return receptionDone && receiveHeaderParsed;
}

// Retrieves the received image pair. Returns true only if
// getPartiallyReceivedImagePair() succeeded and reported the image pair as
// complete; imagePair is filled in by that call in either case.
bool ImageProtocol::Pimpl::getReceivedImagePair(ImagePair& imagePair) {
    int rowsIgnored = 0;
    bool finished = false;

    if(!getPartiallyReceivedImagePair(imagePair, rowsIgnored, finished)) {
        return false;
    }
    return finished;
}

// Retrieves the image pair as far as it has been received.
//
// imagePair: receives the meta data from the header and pointers to the
//            (possibly partially) decoded pixel data. The pixel pointers
//            refer to internal buffers and may be null if no payload has
//            arrived yet. Width and height are set to 0 if no header has
//            been parsed.
// validRows: receives the number of rows available in both images. It is
//            now also set (to 0) when the function returns false; it used to
//            be left untouched on that path.
// complete:  set to true if validRows equals the image height.
//
// Returns false if no header has been parsed yet, true otherwise. If a
// complete frame had been received, the reception state is reset so that the
// next frame can be received.
bool ImageProtocol::Pimpl::getPartiallyReceivedImagePair(ImagePair& imagePair, int& validRows, bool& complete) {
    imagePair.setWidth(0);
    imagePair.setHeight(0);

    complete = false;
    validRows = 0;

    if(!receiveHeaderParsed) {
        // We haven't even received the image header yet
        return false;
    }

    // We received at least the header, possibly some pixel data
    int receivedBytes = 0;
    unsigned char* data = dataProt.getReceivedData(receivedBytes);
    if(receivedBytes == receiveTotalSize) {
        // Receiving this frame has completed
        dataProt.finishReception();
    }

    imagePair.setWidth(receiveHeader.width);
    imagePair.setHeight(receiveHeader.height);
    imagePair.setPixelFormat(0, static_cast<ImagePair::ImageFormat>(receiveHeader.format0));
    imagePair.setPixelFormat(1, static_cast<ImagePair::ImageFormat>(receiveHeader.format1));

    // Decode whatever has arrived for both images
    int rowStride0 = 0, rowStride1 = 0;
    int validRows0 = 0, validRows1 = 0;
    unsigned char* pixel0 = decodeInterleaved(0, receivedBytes, data, validRows0, rowStride0);
    unsigned char* pixel1 = decodeInterleaved(1, receivedBytes, data, validRows1, rowStride1);

    imagePair.setRowStride(0, rowStride0);
    imagePair.setRowStride(1, rowStride1);
    imagePair.setPixelData(0, pixel0);
    imagePair.setPixelData(1, pixel1);
    imagePair.setQMatrix(receiveHeader.q);

    imagePair.setSequenceNumber(receiveHeader.seqNum);
    imagePair.setTimestamp(receiveHeader.timeSec, receiveHeader.timeMicrosec);
    imagePair.setDisparityRange(receiveHeader.minDisparity, receiveHeader.maxDisparity);

    // Only rows present in both images count as valid
    validRows = min(validRows0, validRows1);

    if(validRows == receiveHeader.height) {
        complete = true;
    }

    if(receptionDone) {
        // Reset everything for receiving the next image
        resetReception();
    }

    return true;
}

unsigned char* ImageProtocol::Pimpl::decodeInterleaved(int imageNumber, int receivedBytes,
        unsigned char* data, int& validRows, int& rowStride) {
    if(receivedBytes <= static_cast<int>(sizeof(HeaderData))) {
        // We haven't yet received any data for the requested image
        return nullptr;
    }

    ImagePair::ImageFormat format = static_cast<ImagePair::ImageFormat>(
        imageNumber == 0 ? receiveHeader.format0 : receiveHeader.format1);
    int nibbles0 = getFormatNibbles(static_cast<ImagePair::ImageFormat>(receiveHeader.format0));
    int nibbles1 = getFormatNibbles(static_cast<ImagePair::ImageFormat>(receiveHeader.format1));

    unsigned char* ret = nullptr;
    int payloadBytes = receivedBytes - sizeof(HeaderData);

    if(receiveHeader.secondTileWidth == 0) {
        int bufferOffset = sizeof(HeaderData) + imageNumber*receiveHeader.width * nibbles0/2;
        int bufferRowStride = receiveHeader.width*(nibbles0 + nibbles1) / 2;

        if(format == ImagePair::FORMAT_12_BIT) {
            // Perform 12-bit => 16 bit decoding
            allocateDecodeBuffer(imageNumber);
            validRows = payloadBytes / bufferRowStride;
            rowStride = 2*receiveHeader.width;
            int lastRow = lastReceivedPayloadBytes[imageNumber] / bufferRowStride;

            decodeSubpixel(lastRow, validRows, &data[bufferOffset],
                &decodeBuffer[imageNumber][0], bufferRowStride, rowStride, receiveHeader.width);
            ret = &decodeBuffer[imageNumber][0];
        } else {
            // No decoding is neccessary. We can just pass through the
            // data pointer
            ret = &data[bufferOffset];
            rowStride = bufferRowStride;
            validRows = payloadBytes / bufferRowStride;
        }
    } else {
        // Decode the tiled transfer
        decodeTiledImage(imageNumber,
            lastReceivedPayloadBytes[imageNumber], payloadBytes,
            data, receiveHeader.firstTileWidth * (nibbles0 + nibbles1) / 2,
            receiveHeader.secondTileWidth * (nibbles0 + nibbles1) / 2,
            validRows, format);
        ret = &decodeBuffer[imageNumber][0];

        if(format == ImagePair::FORMAT_12_BIT) {
            rowStride = 2*receiveHeader.width;
        } else {
            rowStride = receiveHeader.width;
        }
    }

    lastReceivedPayloadBytes[imageNumber] = payloadBytes;
    return ret;
}

void ImageProtocol::Pimpl::allocateDecodeBuffer(int imageNumber) {
    ImagePair::ImageFormat format = static_cast<ImagePair::ImageFormat>(
        imageNumber == 0 ? receiveHeader.format0 : receiveHeader.format1);
    int bytesPerPixel = (format == ImagePair::FORMAT_12_BIT ? 2 : 1);
    int bufferSize = receiveHeader.width * receiveHeader.height * bytesPerPixel;

    if(decodeBuffer[imageNumber].size() != static_cast<unsigned int>(bufferSize)) {
        decodeBuffer[imageNumber].resize(bufferSize);
    }
}

void ImageProtocol::Pimpl::decodeTiledImage(int imageNumber, int lastReceivedPayloadBytes, int receivedPayloadBytes,
        const unsigned char* data, int firstTileStride, int secondTileStride, int& validRows,
        ImagePair::ImageFormat format) {

    // Allocate a decoding buffer
    allocateDecodeBuffer(imageNumber);

    // Get beginning and end of first tile
    int startFirstTile = lastReceivedPayloadBytes / firstTileStride;
    int stopFirstTile = std::min(receivedPayloadBytes / firstTileStride,
        static_cast<int>(receiveHeader.height));

    // Get beginning and end of second tile
    int secondTileBytes = receivedPayloadBytes - (receiveHeader.height*firstTileStride);
    int lastSecondTileBytes = lastReceivedPayloadBytes - (receiveHeader.height*firstTileStride);
    int startSecondTile = std::max(0, lastSecondTileBytes / secondTileStride);
    int stopSecondTile = std::max(0, secondTileBytes / secondTileStride);
    int firstTileOffset = sizeof(HeaderData) + imageNumber * getFormatNibbles(
        static_cast<ImagePair::ImageFormat>(receiveHeader.format0)) * receiveHeader.firstTileWidth / 2;

    // Decode first tile
    if(format == ImagePair::FORMAT_12_BIT) {
        decodeSubpixel(startFirstTile, stopFirstTile, &data[firstTileOffset], &decodeBuffer[imageNumber][0],
            firstTileStride, 2*receiveHeader.width, receiveHeader.firstTileWidth);
    } else {
        decodeRowsFromTile(startFirstTile, stopFirstTile, &data[firstTileOffset],
            &decodeBuffer[imageNumber][0], firstTileStride, receiveHeader.width,
            receiveHeader.firstTileWidth);
    }

    // Decode second tile
    int secondTileOffset = sizeof(HeaderData) + receiveHeader.height*firstTileStride +
        imageNumber * getFormatNibbles(static_cast<ImagePair::ImageFormat>(receiveHeader.format0)) * receiveHeader.secondTileWidth / 2;

    if(format == ImagePair::FORMAT_12_BIT) {
        decodeSubpixel(startSecondTile, stopSecondTile,
            &data[secondTileOffset], &decodeBuffer[imageNumber][2*receiveHeader.firstTileWidth],
            secondTileStride, 2*receiveHeader.width, receiveHeader.secondTileWidth);
    } else {
        decodeRowsFromTile(startSecondTile, stopSecondTile, &data[secondTileOffset],
            &decodeBuffer[imageNumber][receiveHeader.firstTileWidth],
            secondTileStride, receiveHeader.width, receiveHeader.secondTileWidth);
    }

    validRows = stopSecondTile;
}

// Copies rows [startRow, stopRow) of an 8-bit tile into the destination
// image. Row y is read at src + y*srcStride and written to dst + y*dstStride;
// tileWidth bytes are copied per row.
void ImageProtocol::Pimpl::decodeRowsFromTile(int startRow, int stopRow, unsigned const char* src,
        unsigned char* dst, int srcStride, int dstStride, int tileWidth) {
    int row = startRow;
    while(row < stopRow) {
        const unsigned char* from = &src[row*srcStride];
        unsigned char* to = &dst[row*dstStride];
        memcpy(to, from, tileWidth);
        ++row;
    }
}

// Discards everything known about the frame currently being received: the
// parsed header, the expected frame size, the per-image decode progress and
// the reception state of the underlying data block protocol.
void ImageProtocol::Pimpl::resetReception() {
    // Header and frame size
    receiveHeaderParsed = false;
    receiveTotalSize = 0;

    // Decode progress of both images
    for(int image = 0; image < 2; image++) {
        lastReceivedPayloadBytes[image] = 0;
    }

    dataProt.resetReception();
    receptionDone = false;
}

// Decodes rows [startRow, stopRow) of a 12-bit image into 16-bit values,
// dispatching to the fastest implementation that is compiled in and whose
// width requirement is met.
//
// Each source row starts with rowWidth bytes of integer values, directly
// followed by the packed 4-bit subpixel values. srcStride and dstStride are
// given in bytes. The "aligned" SIMD variants are chosen if both the source
// pointer and the source stride are multiples of the vector size.
void ImageProtocol::Pimpl::decodeSubpixel(int startRow, int stopRow, unsigned const char* src,
        unsigned char* dst, int srcStride, int dstStride, int rowWidth) {

    const unsigned char* dispStart = src;
    const unsigned char* subpixStart = &src[rowWidth];

#   ifdef __AVX2__
    if(rowWidth % 32 == 0) {
        if(srcStride % 32 == 0 && reinterpret_cast<size_t>(src) % 32 == 0) {
            decodeSubpixelAVX2<true>(startRow, stopRow, dispStart, subpixStart,
                rowWidth, reinterpret_cast<unsigned short*>(dst), srcStride, dstStride);
        } else {
            decodeSubpixelAVX2<false>(startRow, stopRow, dispStart, subpixStart,
                rowWidth, reinterpret_cast<unsigned short*>(dst), srcStride, dstStride);
        }
    } else // Fall through to the SSE2 implementation if the row width is not
           // divisible by 32
#   endif
#   ifdef __SSE2__
    if(rowWidth % 16 == 0) {
        if(srcStride % 16 == 0 && reinterpret_cast<size_t>(src) % 16 == 0) {
            decodeSubpixelSSE2<true>(startRow, stopRow, dispStart, subpixStart,
                rowWidth, reinterpret_cast<unsigned short*>(dst), srcStride, dstStride);
        } else {
            decodeSubpixelSSE2<false>(startRow, stopRow, dispStart, subpixStart,
                rowWidth, reinterpret_cast<unsigned short*>(dst), srcStride, dstStride);
        }
    } else // Fall through to the scalar implementation if the row width is not
           // divisible by 16
#   endif
    {
        decodeSubpixelFallback(startRow, stopRow, dispStart, subpixStart, rowWidth,
            reinterpret_cast<unsigned short*>(dst), srcStride, dstStride);
    }
}

#ifdef __SSE2__
// SSE2 implementation of the 12-bit => 16-bit decoding.
//
// For rows [startRow, stopRow) each output value is (integer byte << 4) |
// subpixel nibble, where even pixels use the low nibble and odd pixels the
// high nibble of the packed subpixel bytes.
//
// alignedLoad: if true, the source rows are read with aligned loads; the
//              caller must guarantee 16-byte alignment of the source pointer
//              and stride.
// width:       pixels per row; must be a multiple of 16, otherwise
//              ProtocolException is thrown.
// srcStride / dstStride: row strides in bytes.
//
// Changes over the previous version:
//  - The second half of a 32-pixel block is skipped whenever the row has
//    ended, not only in the unaligned variant. With aligned input and
//    width % 32 == 16 the old code decoded 16 surplus pixels per row and
//    shifted all following output rows.
//  - Results are written with unaligned stores, because the destination
//    (e.g. the second tile of a tiled image) is not guaranteed to be 16-byte
//    aligned.
//  - If only 16 pixels remain in a row, only the 8 subpixel bytes that belong
//    to them are loaded.
template <bool alignedLoad>
void ImageProtocol::Pimpl::decodeSubpixelSSE2(int startRow, int stopRow, const unsigned char* dispStart,
        const unsigned char* subpixStart, int width, unsigned short* dst, int srcStride, int dstStride) {
    if(width % 16 != 0) {
        throw ProtocolException("Image width must be a multiple of 16!");
    }

    // SSE optimized code
    const __m128i zero = _mm_set1_epi8(0x00);
    const __m128i subpixMask = _mm_set1_epi8(0x0f);
    unsigned char* outPos = &reinterpret_cast<unsigned char*>(dst)[startRow*dstStride];
    const int outRowPadding = dstStride - 2*width;

    for(int y = startRow; y<stopRow; y++) {
        const unsigned char* intPos = &dispStart[y*srcStride];
        const unsigned char* intEndPos = &dispStart[y*srcStride + width];
        const unsigned char* subpixPos = &subpixStart[y*srcStride];

        while(intPos < intEndPos) {
            // Get subpix offsets: 16 bytes carry the nibbles of 32 pixels
            __m128i subpixOffsets;
            if(intEndPos - intPos >= 32) {
                if(alignedLoad) {
                    subpixOffsets = _mm_load_si128(reinterpret_cast<const __m128i*>(subpixPos));
                } else {
                    subpixOffsets = _mm_loadu_si128(reinterpret_cast<const __m128i*>(subpixPos));
                }
                subpixPos += 16;
            } else {
                // Only 16 pixels left in this row: load just their 8 bytes
                subpixOffsets = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(subpixPos));
                subpixPos += 8;
            }

            const __m128i offsetsEven = _mm_and_si128(subpixOffsets, subpixMask);
            const __m128i offsetsUneven = _mm_and_si128(_mm_srli_epi16(subpixOffsets, 4), subpixMask);

            // Each pass handles 16 pixels; the second pass is skipped if the
            // row ended after the first one
            for(int i=0; i<2 && intPos < intEndPos; i++) {
                // Load integer disparities
                __m128i intDisps;
                if(alignedLoad) {
                    intDisps = _mm_load_si128(reinterpret_cast<const __m128i*>(intPos));
                } else {
                    intDisps = _mm_loadu_si128(reinterpret_cast<const __m128i*>(intPos));
                }

                intPos += 16;

                // Get integer disparities shifted by 4
                __m128i disps1 = _mm_slli_epi16(_mm_unpacklo_epi8(intDisps, zero), 4);
                __m128i disps2 = _mm_slli_epi16(_mm_unpackhi_epi8(intDisps, zero), 4);

                // Unpack subpixel offsets for selected disparities
                __m128i offsets;
                if(i == 0) {
                    offsets = _mm_unpacklo_epi8(offsetsEven, offsetsUneven);
                } else  {
                    offsets = _mm_unpackhi_epi8(offsetsEven, offsetsUneven);
                }

                // Add subpixel offsets to integer disparities
                disps1 = _mm_or_si128(disps1, _mm_unpacklo_epi8(offsets, zero));
                disps2 = _mm_or_si128(disps2, _mm_unpackhi_epi8(offsets, zero));

                // Store result
                _mm_storeu_si128(reinterpret_cast<__m128i*>(outPos), disps1);
                outPos += 16;
                _mm_storeu_si128(reinterpret_cast<__m128i*>(outPos), disps2);
                outPos += 16;
            }
        }

        outPos += outRowPadding;
    }
}
#endif

# ifdef __AVX2__
// AVX2 implementation of the 12-bit => 16-bit decoding.
//
// For rows [startRow, stopRow) each output value is (integer byte << 4) |
// subpixel nibble, where even pixels use the low nibble and odd pixels the
// high nibble of the packed subpixel bytes.
//
// alignedLoad: if true, the source rows are read with aligned loads; the
//              caller must guarantee 32-byte alignment of the source pointer
//              and stride.
// width:       pixels per row; must be a multiple of 32, otherwise
//              ProtocolException is thrown (the dispatch in decodeSubpixel()
//              selects another implementation for such widths).
// srcStride / dstStride: row strides in bytes.
//
// Changes over the previous version:
//  - The second half of a 64-pixel block is skipped whenever the row has
//    ended, not only in the unaligned variant. With aligned input and
//    width % 64 == 32 the old code decoded 32 surplus pixels per row and
//    shifted all following output rows.
//  - Results are written with unaligned stores, because the destination
//    (e.g. the second tile of a tiled image) is not guaranteed to be 32-byte
//    aligned.
//  - If only 32 pixels remain in a row, only the 16 subpixel bytes that
//    belong to them are loaded.
template <bool alignedLoad>
void ImageProtocol::Pimpl::decodeSubpixelAVX2(int startRow, int stopRow, const unsigned char* dispStart,
        const unsigned char* subpixStart, int width, unsigned short* dst, int srcStride, int dstStride) {
    if(width % 32 != 0) {
        throw ProtocolException("Image width must be a multiple of 32!");
    }

    // AVX2 optimized code
    const __m256i zero = _mm256_set1_epi8(0x00);
    const __m256i subpixMask = _mm256_set1_epi8(0x0f);
    unsigned char* outPos = &reinterpret_cast<unsigned char*>(dst)[startRow*dstStride];
    const int outRowPadding = dstStride - 2*width;

    for(int y = startRow; y<stopRow; y++) {
        const unsigned char* intPos = &dispStart[y*srcStride];
        const unsigned char* intEndPos = &dispStart[y*srcStride + width];
        const unsigned char* subpixPos = &subpixStart[y*srcStride];

        while(intPos < intEndPos) {
            // Get subpix offsets: 32 bytes carry the nibbles of 64 pixels
            __m256i subpixOffsets;

            if(intEndPos - intPos >= 64) {
                if(alignedLoad) {
                    subpixOffsets = _mm256_load_si256(reinterpret_cast<const __m256i*>(subpixPos));
                } else {
                    subpixOffsets = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(subpixPos));
                }
                subpixPos += 32;
            } else {
                // Only 32 pixels left in this row: load just their 16 bytes
                // into the lower half; the upper half is only used by the
                // second pass, which is skipped in this case.
                const __m128i lowHalf = _mm_loadu_si128(reinterpret_cast<const __m128i*>(subpixPos));
                subpixOffsets = _mm256_inserti128_si256(_mm256_setzero_si256(), lowHalf, 0);
                subpixPos += 16;
            }

            const __m256i offsetsEven = _mm256_and_si256(subpixOffsets, subpixMask);
            const __m256i offsetsUneven = _mm256_and_si256(_mm256_srli_epi16 (subpixOffsets, 4), subpixMask);

            // The AVX2 unpack instructions work on each 128-bit lane
            // separately. Swapping the two middle 64-bit blocks beforehand
            // makes the unpack results come out in pixel order.
            const __m256i offsetsEvenMixup = _mm256_permute4x64_epi64(offsetsEven, 0xd8);
            const __m256i offsetsUnevenMixup = _mm256_permute4x64_epi64(offsetsUneven, 0xd8);

            // Each pass handles 32 pixels; the second pass is skipped if the
            // row ended after the first one
            for(int i=0; i<2 && intPos < intEndPos; i++) {
                // Load integer disparities
                __m256i intDisps;
                if(alignedLoad) {
                    intDisps = _mm256_load_si256(reinterpret_cast<const __m256i*>(intPos));
                } else {
                    intDisps = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(intPos));
                }
                intPos += 32;

                // Same lane swap as above for the integer disparities
                const __m256i intDispsMixup = _mm256_permute4x64_epi64(intDisps, 0xd8);

                // Get integer disparities shifted by 4
                __m256i disps1 = _mm256_slli_epi16(_mm256_unpacklo_epi8(intDispsMixup, zero), 4);
                __m256i disps2 = _mm256_slli_epi16(_mm256_unpackhi_epi8(intDispsMixup, zero), 4);

                // Unpack subpixel offsets for selected disparities
                __m256i offsets;
                if(i == 0) {
                    offsets = _mm256_unpacklo_epi8(offsetsEvenMixup, offsetsUnevenMixup);
                } else  {
                    offsets = _mm256_unpackhi_epi8(offsetsEvenMixup, offsetsUnevenMixup);
                }

                // And one more lane swap before widening the offsets
                const __m256i offsetsMixup = _mm256_permute4x64_epi64(offsets, 0xd8);

                // Add subpixel offsets to integer disparities
                disps1 = _mm256_or_si256(disps1, _mm256_unpacklo_epi8(offsetsMixup, zero));
                disps2 = _mm256_or_si256(disps2, _mm256_unpackhi_epi8(offsetsMixup, zero));

                // Store result
                _mm256_storeu_si256(reinterpret_cast<__m256i*>(outPos), disps1);
                outPos += 32;
                _mm256_storeu_si256(reinterpret_cast<__m256i*>(outPos), disps2);
                outPos += 32;
            }
        }

        outPos += outRowPadding;
    }
}
# endif

// Portable implementation of the 12-bit => 16-bit decoding without SIMD
// instructions.
//
// For rows [startRow, stopRow) each output value is (integer byte << 4) |
// subpixel nibble. One subpixel byte serves two pixels: the even pixel uses
// the low nibble, the odd pixel the high nibble. srcStride and dstStride are
// given in bytes; dst is addressed in 16-bit units, hence dstStride/2.
void ImageProtocol::Pimpl::decodeSubpixelFallback(int startRow, int stopRow, const unsigned char* dispStart,
        const unsigned char* subpixStart, int width, unsigned short* dst, int srcStride, int dstStride) {

    const int dstStrideShort = dstStride/2;

    for(int y = startRow; y < stopRow; y++) {
        const int srcRowOffset = y*srcStride;
        const int dstRowOffset = y*dstStrideShort;

        for(int x = 0; x < width; x++) {
            const unsigned char packed = subpixStart[srcRowOffset + x/2];
            const unsigned short subpix = (x % 2 == 0) ? (packed & 0x0F) : (packed >> 4);
            const unsigned short integerPart = dispStart[srcRowOffset + x];

            dst[dstRowOffset + x] = (integerPart << 4) | subpix;
        }
    }
}
