/*******************************************************************************
 * Copyright (c) 2019 Nerian Vision GmbH
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *******************************************************************************/

#include <algorithm>
#include <cstring>
#include <iostream>
#include <limits>

#include "visiontransfer/datablockprotocol.h"
#include "visiontransfer/exceptions.h"

// Network headers
#ifdef _WIN32
#include <winsock2.h>
#else
#include <arpa/inet.h>
#endif

#define LOG_ERROR(expr)
//#define LOG_ERROR(expr) std::cout << "DataBlockProtocol: " << expr << std::endl

using namespace std;
using namespace visiontransfer;
using namespace visiontransfer::internal;

namespace visiontransfer {
namespace internal {

/**
 * Creates a protocol instance.
 *
 * \param server           True if this instance acts as the server side.
 * \param protType         Transport protocol (TCP or UDP).
 * \param maxUdpPacketSize Largest datagram size; only used when protType is
 *                         not PROTOCOL_TCP.
 */
DataBlockProtocol::DataBlockProtocol(bool server, ProtocolType protType, int maxUdpPacketSize)
        : isServer(server), protType(protType),
        transferDone(true), rawData(nullptr), rawValidBytes(0),
        transferOffset(0), transferSize(0), overwrittenTransferData(0),
        overwrittenTransferIndex(-1), transferHeaderData(nullptr),
        transferHeaderSize(0), waitingForMissingSegments(false),
        totalReceiveSize(0), connectionConfirmed(false),
        confirmationMessagePending(false), eofMessagePending(false),
        clientConnectionPending(false), resendMessagePending(false),
        lastRemoteHostActivity(), lastSentHeartbeat(),
        lastReceivedHeartbeat(std::chrono::steady_clock::now()),
        receiveOffset(0), finishedReception(false), droppedReceptions(0),
        unprocessedMsgLength(0), headerReceived(false) {
    if(protType != PROTOCOL_TCP) {
        // UDP: each datagram ends with a 4-byte segment offset, so the usable
        // payload is smaller than the packet. All but the final segment of a
        // transfer must be full-sized.
        maxPayloadSize = maxUdpPacketSize - sizeof(int);
        minPayloadSize = maxPayloadSize;
    } else {
        // TCP: the stream imposes no lower bound on the chunk size
        maxPayloadSize = MAX_TCP_BYTES_TRANSFER;
        minPayloadSize = 0;
    }

    // Initial allocation; totalReceiveSize is still 0 at this point
    resizeReceiveBuffer();
}

/**
 * Puts the sending side back into the idle state: no transfer is active,
 * the offsets are zeroed and all pending retransmissions are discarded.
 */
void DataBlockProtocol::resetTransfer() {
    missingTransferSegments.clear();
    transferSize = 0;
    transferOffset = 0;
    overwrittenTransferIndex = -1;
    transferDone = true;
}

/**
 * Prepares the header message for a new outgoing transfer.
 *
 * The message is assembled in place around the caller's buffer:
 * data[-6..-5] receives the header size and data[-4..-1] the total transfer
 * size (both in network byte order). In UDP mode, 5 further bytes are written
 * directly after the header (a HEADER_MESSAGE type byte and a 0xFFFFFFFF
 * control marker). The caller must therefore provide writable space on both
 * sides of `data`.
 *
 * \param data          Start of the user header bytes.
 * \param headerSize    Number of user header bytes.
 * \param transferSize  Total number of payload bytes that will follow.
 * \throws ProtocolException if a transfer is already in progress or the
 *         header is too large.
 */
void DataBlockProtocol::setTransferHeader(unsigned char* data, int headerSize, int transferSize) {
    const bool transferRunning = !transferDone && transferOffset > 0;
    if(transferRunning) {
        throw ProtocolException("Header data set while transfer is active!");
    }
    if(headerSize + 9 > static_cast<int>(sizeof(controlMessageBuffer))) {
        throw ProtocolException("Transfer header is too large!");
    }

    transferDone = false;
    this->transferSize = transferSize;

    // The message begins 6 bytes in front of the caller's header data
    unsigned char* message = data - 6;
    transferHeaderData = message;

    const unsigned short netHeaderSize = htons(static_cast<unsigned short>(headerSize));
    const unsigned int netTransferSize = htonl(static_cast<unsigned int>(transferSize));
    memcpy(&message[0], &netHeaderSize, sizeof(netHeaderSize));
    memcpy(&message[2], &netTransferSize, sizeof(netTransferSize));

    int messageSize = headerSize + 6;
    if(protType == PROTOCOL_UDP) {
        // UDP sends the header as a control message: type byte + marker
        message[messageSize] = HEADER_MESSAGE;
        memset(&message[messageSize + 1], 0xFF, 4);
        messageSize += 5;
    }

    transferHeaderSize = messageSize;
}

/**
 * Sets the payload buffer for the current transfer and restarts sending
 * from offset 0.
 *
 * \param data        Payload buffer (not copied; must stay valid while sending).
 * \param validBytes  Bytes of `data` that are ready; clamped to transferSize.
 * \throws ProtocolException if setTransferHeader() has not provided a header.
 */
void DataBlockProtocol::setTransferData(unsigned char* data, int validBytes) {
    const bool headerMissing = (transferHeaderData == nullptr) || (transferHeaderSize == 0);
    if(headerMissing) {
        throw ProtocolException("The transfer header has not yet been set!");
    }

    rawData = data;
    rawValidBytes = (validBytes < transferSize) ? validBytes : transferSize;
    transferOffset = 0;
    overwrittenTransferIndex = -1;
    transferDone = false;
}

/**
 * Updates how many bytes of the payload buffer are ready to be sent.
 *
 * The value is clamped to transferSize. Amounts smaller than sizeof(int)
 * are treated as "nothing available yet".
 */
void DataBlockProtocol::setTransferValidBytes(int validBytes) {
    if(validBytes < transferSize) {
        rawValidBytes = (validBytes < static_cast<int>(sizeof(int))) ? 0 : validBytes;
    } else {
        rawValidBytes = transferSize;
    }
}

/**
 * Returns the next chunk of the outgoing transfer.
 *
 * In TCP mode, the header is returned first. In UDP mode, the 4-byte
 * network-order segment offset is written directly behind the chunk inside
 * the caller's buffer. The bytes it replaces are saved and put back by
 * restoreTransferBuffer() on the next call.
 *
 * \param length  Receives the message size in bytes; 0 if nothing is ready.
 * \return Pointer to the message, or nullptr if nothing can be sent now.
 */
const unsigned char* DataBlockProtocol::getTransferMessage(int& length) {
    if(transferDone || rawValidBytes == 0) {
        // No more data to be transferred
        length = 0;
        return nullptr;
    }

    // For TCP we always send the header first
    if(protType == PROTOCOL_TCP && transferOffset == 0 && transferHeaderData != nullptr) {
        length = transferHeaderSize;
        const unsigned char* ret = transferHeaderData;
        transferHeaderData = nullptr;
        return ret;
    }

    // The previous UDP chunk may have overwritten bytes of the buffer with
    // its offset trailer; restore them first
    restoreTransferBuffer();

    // Determine which data segment to transfer next
    int offset = 0;
    getNextTransferSegment(offset, length);
    if(length == 0) {
        return nullptr;
    }

    if(protType == PROTOCOL_UDP) {
        // Append the segment offset behind the payload. memcpy is used
        // because this position is generally not int-aligned, and accessing
        // a byte buffer through an int* violates strict aliasing.
        overwrittenTransferIndex = offset + length;
        unsigned char* trailer = &rawData[overwrittenTransferIndex];

        int savedBytes = 0;
        memcpy(&savedBytes, trailer, sizeof(savedBytes));
        overwrittenTransferData = savedBytes;

        const int netOffset = static_cast<int>(htonl(static_cast<unsigned int>(offset)));
        memcpy(trailer, &netOffset, sizeof(netOffset));
        length += sizeof(int);
    }

    return &rawData[offset];
}

/**
 * Selects the next payload range to send.
 *
 * Retransmission requests (missingTransferSegments) are served first, in
 * chunks of at most maxPayloadSize. Otherwise the next regular chunk after
 * transferOffset is chosen. A chunk shorter than minPayloadSize is held back
 * until all data is valid (rawValidBytes == transferSize).
 *
 * \param offset  Receives the start of the range (unset when length is 0).
 * \param length  Receives the range size; 0 if nothing should be sent now.
 */
void DataBlockProtocol::getNextTransferSegment(int& offset, int& length) {
    if(!missingTransferSegments.empty()) {
        // Re-transmission of a range the receiver reported as lost
        auto& segment = missingTransferSegments.front();
        offset = segment.first;
        length = min(maxPayloadSize, segment.second);
        LOG_ERROR("Re-transmitting: " << offset << " -  " << (offset + length));

        const int remaining = segment.second - length;
        if(remaining != 0) {
            // Only part of the range fits into this chunk
            segment.first += length;
            segment.second = remaining;
        } else {
            // The range is completed
            missingTransferSegments.pop_front();
        }
        return;
    }

    // Regular data segment
    length = min(maxPayloadSize, rawValidBytes - transferOffset);
    const bool waitForMoreData = length < minPayloadSize && rawValidBytes != transferSize;
    if(length == 0 || waitForMoreData) {
        length = 0;
        return;
    }

    offset = transferOffset;
    transferOffset += length;

    if(protType == PROTOCOL_UDP && transferOffset >= transferSize) {
        // All data sent once: UDP must announce the end of the frame
        eofMessagePending = true;
    }
}

/**
 * Writes back the bytes that getTransferMessage() replaced with a UDP
 * segment-offset trailer. After this call, no restoration is pending.
 */
void DataBlockProtocol::restoreTransferBuffer() {
    if(overwrittenTransferIndex > 0) {
        // memcpy: the index is generally not int-aligned, and writing an int
        // through a reinterpreted byte pointer violates strict aliasing
        const int savedBytes = overwrittenTransferData;
        memcpy(&rawData[overwrittenTransferIndex], &savedBytes, sizeof(savedBytes));
    }
    overwrittenTransferIndex = -1;
}

/**
 * Returns true once all payload has been handed out and no end-of-frame
 * message is still waiting to be sent.
 */
bool DataBlockProtocol::transferComplete() {
    if(eofMessagePending) {
        return false;
    }
    return transferOffset >= transferSize;
}

/**
 * Returns the largest number of bytes a single receive call may deliver:
 * MAX_TCP_BYTES_TRANSFER for TCP, MAX_UDP_RECEPTION otherwise.
 */
int DataBlockProtocol::getMaxReceptionSize() const {
    if(protType != PROTOCOL_TCP) {
        return MAX_UDP_RECEPTION;
    }
    return MAX_TCP_BYTES_TRANSFER;
}

/**
 * Returns the position in receiveBuffer where the next network message
 * should be stored (receiveBuffer[receiveOffset]).
 *
 * \param maxLength  Number of bytes the caller may write there.
 * \throws ProtocolException if fewer than maxLength bytes remain.
 */
unsigned char* DataBlockProtocol::getNextReceiveBuffer(int maxLength) {
    const int available = static_cast<int>(receiveBuffer.size() - receiveOffset);
    if(available >= maxLength) {
        return &receiveBuffer[receiveOffset];
    }
    throw ProtocolException("No more receive buffers available!");
}

/**
 * Processes a message that was written into the buffer returned by
 * getNextReceiveBuffer().
 *
 * \param length            Number of bytes received; values <= 0 are ignored.
 * \param transferComplete  Set to true if this message finished a reception.
 */
void DataBlockProtocol::processReceivedMessage(int length, bool& transferComplete) {
    transferComplete = false;
    if(length > 0) {
        // A finished reception is kept until the next message arrives and
        // is only cleared now, before handling the new data
        if(finishedReception) {
            resetReception(false);
        }

        if(protType != PROTOCOL_UDP) {
            processReceivedTcpMessage(length, transferComplete);
        } else {
            processReceivedUdpMessage(length, transferComplete);
        }

        transferComplete = finishedReception;
    }
}

/**
 * Handles one UDP datagram stored at receiveBuffer[receiveOffset].
 *
 * Every datagram ends with a 4-byte segment offset in network byte order.
 * 0xFFFFFFFF marks a control message. Any other non-negative value is the
 * position of the payload within the transfer. Data datagrams are ignored
 * until a header has been received.
 *
 * \param length            Datagram size in bytes, including the trailer.
 * \param transferComplete  Not written here; completion is reported through
 *                          finishedReception.
 * \throws ProtocolException for malformed datagrams.
 */
void DataBlockProtocol::processReceivedUdpMessage(int length, bool& transferComplete) {
    if(length < static_cast<int>(sizeof(int)) ||
            receiveOffset + length > static_cast<int>(receiveBuffer.size())) {
        throw ProtocolException("Received message size is invalid!");
    }

    // Extract the segment offset. memcpy: the trailer is generally not
    // int-aligned, and reading it through an int* violates strict aliasing.
    unsigned int netSegmentOffset = 0;
    memcpy(&netSegmentOffset, &receiveBuffer[receiveOffset + length - sizeof(int)],
        sizeof(netSegmentOffset));
    int segmentOffset = static_cast<int>(ntohl(netSegmentOffset));

    if(segmentOffset == static_cast<int>(0xFFFFFFFF)) {
        // This is a control packet
        processControlMessage(length);
    } else if(segmentOffset < 0) {
        throw ProtocolException("Received illegal network packet");
    } else if(headerReceived) {
        // Correct the length by subtracting the size of the segment offset
        int payloadLength = length - sizeof(int);

        if(segmentOffset != receiveOffset) {
            // The segment offset doesn't match what we expected. Probably
            // a packet was dropped. The bounds test is written as a
            // subtraction so a huge offset from the network cannot overflow.
            if(!waitingForMissingSegments && receiveOffset > 0 && segmentOffset > receiveOffset
                    && payloadLength < static_cast<int>(receiveBuffer.size()) - segmentOffset) {
                // We can just ask for a retransmission of this packet
                LOG_ERROR("Missing segment: " << receiveOffset << " - " << segmentOffset
                    << " (" << missingReceiveSegments.size() << ")");

                MissingReceiveSegment missingSeg;
                missingSeg.offset = receiveOffset;
                missingSeg.length = segmentOffset - receiveOffset;
                missingSeg.isEof = false;
                // Remember the first payload bytes so they can be restored
                // once the gap has been filled
                memcpy(missingSeg.subsequentData, &receiveBuffer[receiveOffset],
                    sizeof(missingSeg.subsequentData));
                missingReceiveSegments.push_back(missingSeg);

                // Move the received data to the right place in the buffer.
                // memmove: source and destination overlap whenever the gap
                // is shorter than the payload.
                memmove(&receiveBuffer[segmentOffset], &receiveBuffer[receiveOffset], payloadLength);
                receiveOffset = segmentOffset;
            } else {
                // In this case we cannot recover from the packet loss or
                // we just didn't get the EOF packet and everything is
                // actually fine
                resetReception(receiveOffset > 0);
                if(segmentOffset > 0 ) {
                    if(receiveOffset > 0) {
                        LOG_ERROR("Resend failed!");
                    }
                    return;
                } else {
                    LOG_ERROR("Missed EOF message!");
                }
            }
        }

        if(segmentOffset == 0) {
            // This is the beginning of a new frame
            lastRemoteHostActivity = std::chrono::steady_clock::now();
        }

        // Update the receive buffer offset
        receiveOffset = getNextUdpReceiveOffset(segmentOffset, payloadLength);
    }
}

int DataBlockProtocol::getNextUdpReceiveOffset(int lastSegmentOffset, int lastSegmentSize) {
    if(!waitingForMissingSegments) {
        // Just need to increment the offset during normal transmission
        return lastSegmentOffset + lastSegmentSize;
    } else {
        // Things get more complicated when re-transmitting dropped packets
        MissingReceiveSegment& firstSeg = missingReceiveSegments.front();
        if(lastSegmentOffset != firstSeg.offset) {
            LOG_ERROR("Received invalid resend: " << lastSegmentOffset);
            resetReception(true);
            return 0;
        } else {
            firstSeg.offset += lastSegmentSize;
            firstSeg.length -= lastSegmentSize;
            if(firstSeg.length == 0) {
                if(!firstSeg.isEof) {
                    memcpy(&receiveBuffer[firstSeg.offset + firstSeg.length],
                        firstSeg.subsequentData, sizeof(firstSeg.subsequentData));
                }
                missingReceiveSegments.pop_front();
            }

            if(missingReceiveSegments.size() == 0) {
                waitingForMissingSegments = false;
                finishedReception = true;
                return min(totalReceiveSize, static_cast<int>(receiveBuffer.size()));
            } else {
                return missingReceiveSegments.front().offset;
            }
        }
    }
}

/**
 * Handles a chunk of TCP stream data that was received into receiveBuffer.
 *
 * TCP does not preserve message boundaries. A chunk may end in the middle
 * of the transfer header, or contain bytes that already belong to the next
 * transfer. Such bytes are parked in unprocessedMsgPart and prepended to the
 * next chunk.
 *
 * \param length            Number of newly received bytes.
 * \param transferComplete  Not written here; completion is reported through
 *                          finishedReception.
 * \throws ProtocolException if data cannot be buffered or the header is invalid.
 */
void DataBlockProtocol::processReceivedTcpMessage(int length, bool& transferComplete) {
    // Prepend bytes left over from the previous chunk.
    // NOTE(review): this assumes the new chunk was received at
    // receiveBuffer[0]; confirm against the caller of getNextReceiveBuffer().
    if(unprocessedMsgLength != 0) {
        // Only receiveBuffer must hold the combined data; MAX_OUTSTANDING_BYTES
        // bounds what is parked in unprocessedMsgPart, not the new chunk.
        if(unprocessedMsgLength + length > static_cast<int>(receiveBuffer.size())) {
            throw ProtocolException("Received too much data!");
        }

        ::memmove(&receiveBuffer[unprocessedMsgLength], &receiveBuffer[0], length);
        ::memcpy(&receiveBuffer[0], &unprocessedMsgPart[0], unprocessedMsgLength);
        length += unprocessedMsgLength;
        unprocessedMsgLength = 0;
    }

    // In TCP mode the header must be the first data item to be transmitted
    if(!headerReceived) {
        int totalHeaderSize = parseReceivedHeader(length, receiveOffset);
        if(totalHeaderSize == 0) {
            // Header not complete yet: park everything received so far
            if(length > MAX_OUTSTANDING_BYTES) {
                throw ProtocolException("Received too much data!");
            }
            ::memcpy(unprocessedMsgPart, &receiveBuffer[0], length);
            unprocessedMsgLength = length;
            return;
        }

        // Header successfully parsed; move the remaining data to the front
        length -= totalHeaderSize;
        if(length == 0) {
            return; // No more data remaining
        }
        ::memmove(&receiveBuffer[0], &receiveBuffer[totalHeaderSize], length);
    }

    // The chunk may also contain the first bytes of the next transfer;
    // park them in the separate buffer
    if(receiveOffset + length > totalReceiveSize) {
        int newLength = static_cast<int>(totalReceiveSize - receiveOffset);

        if(unprocessedMsgLength != 0 || length - newLength > MAX_OUTSTANDING_BYTES) {
            throw ProtocolException("Received too much data!");
        }

        unprocessedMsgLength = length - newLength;
        ::memcpy(unprocessedMsgPart, &receiveBuffer[receiveOffset + newLength], unprocessedMsgLength);

        length = newLength;
    }

    // Advancing the receive offset in TCP mode just requires an increment
    receiveOffset += length;

    if(receiveOffset == totalReceiveSize) {
        // We are done once we received the expected amount of data
        finishedReception = true;
    }
}

/**
 * Parses a transfer header starting at receiveBuffer[offset].
 *
 * Layout: a 2-byte header size and a 4-byte total transfer size (both in
 * network byte order), followed by the header payload.
 *
 * \param length  Number of bytes available at offset.
 * \param offset  Position of the header in receiveBuffer.
 * \return Number of bytes occupied by the complete header, or 0 if more
 *         data is needed before the header is complete.
 * \throws ProtocolException if the header fields can never be valid.
 */
int DataBlockProtocol::parseReceivedHeader(int length, int offset) {
    constexpr int headerExtraBytes = 6;

    if(length < headerExtraBytes) {
        return 0;
    }

    // memcpy: the fields are generally not aligned for direct access
    unsigned short netHeaderSize = 0;
    unsigned int netTotalSize = 0;
    memcpy(&netHeaderSize, &receiveBuffer[offset], sizeof(netHeaderSize));
    memcpy(&netTotalSize, &receiveBuffer[offset + 2], sizeof(netTotalSize));
    const int headerSize = ntohs(netHeaderSize);
    const int announcedSize = static_cast<int>(ntohl(netTotalSize));

    if(headerSize + headerExtraBytes > static_cast<int>(receiveBuffer.size())
            || announcedSize < 0) {
        throw ProtocolException("Received invalid header!");
    }

    if(headerSize + headerExtraBytes > length) {
        // Header payload incomplete (e.g. split across TCP reads)
        return 0;
    }

    totalReceiveSize = announcedSize;
    headerReceived = true;
    receivedHeader.assign(receiveBuffer.begin() + offset + headerExtraBytes,
        receiveBuffer.begin() + offset + headerExtraBytes + headerSize);
    resizeReceiveBuffer();

    return headerSize + headerExtraBytes;
}

/**
 * Discards all reception state so the next frame can be received.
 *
 * \param dropped  If true, the discarded reception is counted in
 *                 droppedReceptions.
 */
void DataBlockProtocol::resetReception(bool dropped) {
    if(dropped) {
        droppedReceptions++;
    }

    receiveOffset = 0;
    totalReceiveSize = 0;
    headerReceived = false;
    finishedReception = false;
    waitingForMissingSegments = false;
    receivedHeader.clear();
    missingReceiveSegments.clear();
}

/**
 * Returns the received payload.
 *
 * \param length  Receives the number of contiguous valid bytes: up to
 *                receiveOffset, but never past the first missing segment.
 */
unsigned char* DataBlockProtocol::getReceivedData(int& length) {
    if(missingReceiveSegments.empty()) {
        length = receiveOffset;
    } else {
        length = min(receiveOffset, missingReceiveSegments[0].offset);
    }
    return &receiveBuffer[0];
}

/**
 * Returns the user header of the current reception.
 *
 * \param length  Receives the header size; set to 0 when no header exists
 *                (it was previously left unset, leaving callers with an
 *                indeterminate value).
 * \return Pointer to the header bytes, or nullptr if none was received.
 */
unsigned char* DataBlockProtocol::getReceivedHeader(int& length) {
    if(receivedHeader.empty()) {
        length = 0;
        return nullptr;
    }
    length = static_cast<int>(receivedHeader.size());
    return &receivedHeader[0];
}

/**
 * Handles a UDP control message stored at receiveBuffer[receiveOffset].
 *
 * Layout: [payload][1-byte message type][4-byte 0xFFFFFFFF marker].
 *
 * \param length  Total message size in bytes.
 * \return false if the message is too short to contain a type byte,
 *         true otherwise.
 * \throws ProtocolException for unknown message types or short headers.
 */
bool DataBlockProtocol::processControlMessage(int length) {
    const int overhead = static_cast<int>(sizeof(int) + 1);
    if(length < overhead) {
        return false;
    }

    const int payloadLength = length - overhead;
    const auto messageType = receiveBuffer[receiveOffset + payloadLength];

    if(messageType == CONFIRM_MESSAGE) {
        // Our connection request has been accepted
        connectionConfirmed = true;
    } else if(messageType == CONNECTION_MESSAGE) {
        // A client connects; a confirmation must be sent back
        connectionConfirmed = true;
        confirmationMessagePending = true;
        clientConnectionPending = true;
        // A connection request counts as a heartbeat
        lastReceivedHeartbeat = std::chrono::steady_clock::now();
    } else if(messageType == HEADER_MESSAGE) {
        // The header bytes sit at the current offset, which the reset below
        // sets back to 0, so keep a copy
        const int headerOffset = receiveOffset;
        if(receiveOffset != 0) {
            if(receiveOffset == totalReceiveSize) {
                LOG_ERROR("No EOF message received!");
            } else {
                LOG_ERROR("Received header too late/early!");
            }
            resetReception(true);
        }
        if(parseReceivedHeader(payloadLength, headerOffset) == 0) {
            throw ProtocolException("Received header is too short!");
        }
    } else if(messageType == EOF_MESSAGE) {
        // End of frame; irrelevant if no data has been received
        if(receiveOffset != 0) {
            parseEofMessage(length);
        }
    } else if(messageType == RESEND_MESSAGE) {
        // The receiver requests retransmission of missing ranges
        parseResendMessage(payloadLength);
    } else if(messageType == HEARTBEAT_MESSAGE) {
        lastReceivedHeartbeat = std::chrono::steady_clock::now();
    } else {
        throw ProtocolException("Received invalid control message!");
    }

    return true;
}

/**
 * Reports whether the connection is considered alive.
 *
 * TCP connections are always reported as connected (TCP handles this).
 * In UDP mode a confirmed connection is required. The server additionally
 * requires a heartbeat within the last 2 * HEARTBEAT_INTERVAL_MS.
 */
bool DataBlockProtocol::isConnected() const {
    if(protType == PROTOCOL_TCP) {
        return true;
    }
    if(!connectionConfirmed) {
        return false;
    }
    if(!isServer) {
        return true;
    }

    const auto millisSinceHeartbeat = std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() - lastReceivedHeartbeat).count();
    return millisSinceHeartbeat < 2*HEARTBEAT_INTERVAL_MS;
}

/**
 * Returns the next UDP control message that should be sent, if any.
 *
 * At most one message is produced per call. Candidates are checked in this
 * priority order:
 *   - connection confirmation;
 *   - connection request (client only, after RECONNECT_TIMEOUT_MS without
 *     remote activity);
 *   - pending transfer header (only while connected);
 *   - end-of-frame message;
 *   - resend request;
 *   - heartbeat (client only, after HEARTBEAT_INTERVAL_MS).
 *
 * \param length  Receives the message size in bytes; 0 if nothing is sent.
 * \return Pointer to the message, or nullptr. Always nullptr for TCP.
 */
const unsigned char* DataBlockProtocol::getNextControlMessage(int& length) {
    length = 0;

    if(protType == PROTOCOL_TCP) {
        // There are no control messages for TCP
        return nullptr;
    }

    if(confirmationMessagePending) {
        // Send confirmation message
        confirmationMessagePending = false;
        controlMessageBuffer[0] = CONFIRM_MESSAGE;
        length = 1;
    } else if(!isServer && std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::steady_clock::now() - lastRemoteHostActivity).count() > RECONNECT_TIMEOUT_MS) {
        // Send a new connection request
        controlMessageBuffer[0] = CONNECTION_MESSAGE;
        length = 1;

        // Also update time stamps
        lastRemoteHostActivity = lastSentHeartbeat = std::chrono::steady_clock::now();
    } else if(transferHeaderData != nullptr && isConnected()) {
        // We need to send a new protocol header. The header buffer prepared
        // in setTransferHeader() already ends with its own type byte and
        // control marker, so it is returned directly without the trailer below.
        length = transferHeaderSize;
        const unsigned char* ret = transferHeaderData;
        transferHeaderData = nullptr;
        return ret;
    } else if(eofMessagePending) {
        // Send end of frame message; its payload is the number of bytes sent
        // so far (transferOffset) in network byte order
        eofMessagePending = false;
        unsigned int networkOffset = htonl(static_cast<unsigned int>(transferOffset));
        memcpy(&controlMessageBuffer[0], &networkOffset, sizeof(int));
        controlMessageBuffer[sizeof(int)] = EOF_MESSAGE;
        length = 5;
    } else if(resendMessagePending) {
        // Send a re-send request for missing messages (dropped if it does
        // not fit into controlMessageBuffer)
        resendMessagePending = false;
        if(!generateResendRequest(length)) {
            length = 0;
            return nullptr;
        }
    } else if(!isServer && std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::steady_clock::now() - lastSentHeartbeat).count() > HEARTBEAT_INTERVAL_MS) {
        // Send a heartbeat message
        controlMessageBuffer[0] = HEARTBEAT_MESSAGE;
        length = 1;
        lastSentHeartbeat = std::chrono::steady_clock::now();
    } else {
        return nullptr;
    }

    // Mark this message as a control message: the receiver treats a segment
    // offset of 0xFFFFFFFF as a control packet
    controlMessageBuffer[length++] = 0xff;
    controlMessageBuffer[length++] = 0xff;
    controlMessageBuffer[length++] = 0xff;
    controlMessageBuffer[length++] = 0xff;
    return controlMessageBuffer;
}

/**
 * Returns true exactly once after a connection request has been received,
 * and clears the pending flag.
 */
bool DataBlockProtocol::newClientConnected() {
    const bool wasPending = clientConnectionPending;
    clientConnectionPending = false;
    return wasPending;
}

bool DataBlockProtocol::generateResendRequest(int& length) {
    length = static_cast<int>(missingReceiveSegments.size() * (sizeof(int) + sizeof(unsigned short)));
    if(length + sizeof(int) + 1> sizeof(controlMessageBuffer)) {
        return false;
    }

    length = 0;
    for(MissingReceiveSegment segment: missingReceiveSegments) {
        unsigned int segOffset = htonl(static_cast<unsigned int>(segment.offset));
        unsigned int segLen = htonl(static_cast<unsigned int>(segment.length));

        memcpy(&controlMessageBuffer[length], &segOffset, sizeof(segOffset));
        length += sizeof(unsigned int);
        memcpy(&controlMessageBuffer[length], &segLen, sizeof(segLen));
        length += sizeof(unsigned int);
    }

    controlMessageBuffer[length++] = RESEND_MESSAGE;

    return true;
}

void DataBlockProtocol::parseResendMessage(int length) {
    missingTransferSegments.clear();

    int num = length / (sizeof(unsigned int) + sizeof(unsigned short));
    int bufferOffset = receiveOffset;

    for(int i=0; i<num; i++) {
        unsigned int segOffsetNet = *reinterpret_cast<unsigned int*>(&receiveBuffer[bufferOffset]);
        bufferOffset += sizeof(unsigned int);
        unsigned int segLenNet = *reinterpret_cast<unsigned int*>(&receiveBuffer[bufferOffset]);
        bufferOffset += sizeof(unsigned int);

        int segOffset = static_cast<int>(ntohl(segOffsetNet));
        int segLen = static_cast<int>(ntohl(segLenNet));

        if(segOffset >= 0 && segLen > 0 && segOffset + segLen < rawValidBytes) {
            missingTransferSegments.push_back(std::pair<int, int>(
                segOffset, segLen));
        }

        LOG_ERROR("Requested resend: " << segOffset << " - " << (segOffset + segLen));
    }
}

/**
 * Handles an end-of-frame message received at receiveBuffer[receiveOffset].
 *
 * The first 4 bytes carry the total number of bytes the sender transmitted
 * (network byte order). Any data still missing, including a tail behind the
 * last received segment, is recorded and a resend request is scheduled.
 * Otherwise the reception is marked finished.
 *
 * \param length  Control message size in bytes.
 * \throws ProtocolException if the announced size is smaller than the
 *         amount of data already received.
 */
void DataBlockProtocol::parseEofMessage(int length) {
    if(length < static_cast<int>(sizeof(int))) {
        return;
    }

    // memcpy: the field is generally not int-aligned in the buffer
    unsigned int netTotalSize = 0;
    memcpy(&netTotalSize, &receiveBuffer[receiveOffset], sizeof(netTotalSize));
    totalReceiveSize = static_cast<int>(ntohl(netTotalSize));

    if(totalReceiveSize < receiveOffset) {
        throw ProtocolException("Received invalid EOF message!");
    }
    if(totalReceiveSize != receiveOffset && receiveOffset != 0) {
        // Add final missing segment
        MissingReceiveSegment missingSeg;
        missingSeg.offset = receiveOffset;
        missingSeg.length = totalReceiveSize - receiveOffset;
        missingSeg.isEof = true;
        missingReceiveSegments.push_back(missingSeg);
    }
    if(!missingReceiveSegments.empty()) {
        waitingForMissingSegments = true;
        resendMessagePending = true;
        receiveOffset = missingReceiveSegments[0].offset;
    } else {
        finishedReception = true;
    }
}

/**
 * Grows receiveBuffer so that it can hold the announced transfer plus one
 * additional network message, the outstanding-bytes area and a segment
 * offset. The buffer never shrinks.
 *
 * \throws ProtocolException if totalReceiveSize is negative or the
 *         required size does not fit into an int.
 */
void DataBlockProtocol::resizeReceiveBuffer() {
    if(totalReceiveSize < 0) {
        throw ProtocolException("Received invalid transfer size!");
    }

    // Summed in 64 bits: totalReceiveSize is read from the network header
    // and could otherwise overflow int (undefined behavior)
    const long long bufferSize = static_cast<long long>(totalReceiveSize)
        + static_cast<long long>(getMaxReceptionSize())
        + static_cast<long long>(MAX_OUTSTANDING_BYTES)
        + static_cast<long long>(sizeof(int));

    // The buffer size is cast to int throughout this class, so it must stay
    // within int range
    if(bufferSize > static_cast<long long>(std::numeric_limits<int>::max())) {
        throw ProtocolException("Received invalid transfer size!");
    }

    // Resize the buffer
    if(static_cast<long long>(receiveBuffer.size()) < bufferSize) {
        receiveBuffer.resize(static_cast<std::size_t>(bufferSize));
    }
}

}} // namespace

