/*******************************************************************************
 * Copyright (c) 2017 Nerian Vision Technologies
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *******************************************************************************/

#include <cstdio>
#include <iostream>
#include <cstring>
#include <memory>
#include <fcntl.h>
#include <string>
#include <vector>
#include "visiontransfer/imagetransfer.h"
#include "visiontransfer/exceptions.h"
#include "visiontransfer/datablockprotocol.h"


#ifdef _MSC_VER
    // Visual studio does not come with snprintf
    #define snprintf _snprintf_s
#endif

// Network headers
#ifdef _WIN32
    #ifndef _WIN32_WINNT
        #define _WIN32_WINNT 0x501
    #endif
    #define _WINSOCK_DEPRECATED_NO_WARNINGS
    
    #ifndef NOMINMAX
        #define NOMINMAX
    #endif

    #include <winsock2.h>
    #include <ws2tcpip.h>

    // Some defines to make windows socket look more like
    // posix sockets.
    #ifdef EWOULDBLOCK
        #undef EWOULDBLOCK
    #endif
    #ifdef ECONNRESET
        #undef ECONNRESET
    #endif
    #ifdef ETIMEDOUT
        #undef ETIMEDOUT
    #endif

    #define EWOULDBLOCK WSAEWOULDBLOCK
    #define ECONNRESET WSAECONNRESET
    #define ETIMEDOUT WSAETIMEDOUT
    #define MSG_DONTWAIT 0

    #define close closesocket

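    // Emulate POSIX errno with WSAGetLastError(). This does not work when
    // used directly inside a throw statement, which is why exceptions are
    // first constructed into a local variable and then thrown.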
    #undef errno
    #define errno WSAGetLastError()
    #define strerror win_strerror

    std::string win_strerror(unsigned long error) {
        char* str = nullptr;
        if(FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER |
            FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
            nullptr, error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
            (LPSTR)&str, 0, nullptr) == 0 || str == nullptr) {
            return "Unknown error";
        } else {
            char buffer[512];
            snprintf(buffer, sizeof(buffer), "%s (%lu)", str, error);
            LocalFree(str);
            return std::string(buffer);
        }
    }

#else
    #include <arpa/inet.h>
    #include <netinet/tcp.h>
    #include <sys/types.h>
    #include <sys/socket.h>
    #include <netdb.h>
    #include <netinet/in.h>
    #include <errno.h>
    #include <unistd.h>
    #include <signal.h>

    // Use a Winsock-like socket type so the socket code can be shared with Windows
    typedef int SOCKET;
    #define INVALID_SOCKET -1

    // Also we need some additional Winsock defines
    #define WSA_IO_PENDING 0
#endif

using namespace std;

/*************** Pimpl class containing all private members ***********/

// Private implementation of ImageTransfer. Holds the sockets, the peer
// addresses and the ImageProtocol instance that encodes and decodes the
// network messages.
class ImageTransfer::Pimpl {
public:
    // Resolves the addresses and opens the socket(s) required by the given
    // operation mode; throws TransferException on failure.
    Pimpl(OperationMode mode, const char* remoteAddress, const char* remoteService,
        const char* localAddress, const char* localService, int bufferSize);
    // Closes any open sockets
    ~Pimpl();

    // Redeclaration of public members
    void setRawTransferData(const ImagePair& metaData, unsigned char* rawData,
        int secondTileWidth = 0, int validBytes = 0x7FFFFFFF);
    void setRawValidBytes(int validBytes);
    void setTransferImagePair(const ImagePair& imagePair);
    TransferStatus transferData(bool block);
    bool receiveImagePair(ImagePair& imagePair, bool block);
    bool receivePartialImagePair(ImagePair& imagePair, int& validRows, bool& complete, bool block);
    bool tryAccept();
    bool isClientConnected() const;
    void disconnect();
    std::string getClientAddress() const;

private:
    static constexpr int UDP_BUFFERS = 256;
    static constexpr int TCP_BUFFER_SIZE = 0xFFFF; //64K - 1

    // The chosen operation mode for this connection
    OperationMode mode;

    // The socket file descriptor used for data transfer (in TCP server mode:
    // the socket of the currently accepted client)
    SOCKET socket;

    // In server mode: Socket listening on the server port
    SOCKET serverSocket;

    // In server mode: Address of the connected client, filled in by tryAccept()
    sockaddr_in clientAddress;

    // Destination address for UDP transmissions (used with sendto())
    sockaddr_in udpAddress;

    // Object for encoding and decoding the network protocol
    std::unique_ptr<ImageProtocol> protocol;

    // Outstanding network message that still has to be transferred:
    // total length, number of bytes already sent, and the message data
    int currentMsgLen;
    int currentMsgOffset;
    const unsigned char* currentMsg;

    // Size of the socket buffers (SO_RCVBUF / SO_SNDBUF); only applied if > 0
    int bufferSize;

    // Buffered error state of the current reception: set in TCP client mode
    // when the protocol rejects a received message, and reported (once) by
    // receivePartialImagePair()
    bool receptionFailed;
    
    // Blocking mode last applied to the socket by win32SetBlocking()
    // (only maintained on Windows)
    bool socketIsBlocking;

    // Applies the common socket options (blocking mode, buffer sizes,
    // timeouts, ...) to the data socket; used for TCP and UDP
    void setSocketOptions();

    // Initialization functions for different operation modes
    void initTcpClient(const addrinfo& remoteAddressInfo, const addrinfo& localAddressInfo);
    void initTcpServer(const addrinfo& localAddressInfo);
    void initUdp(const addrinfo& remoteAddressInfo, const addrinfo& localAddressInfo);

    // Receives some network data and forwards it to the protocol
    int receiveSingleNetworkMessages(bool block);
    bool receiveNetworkData(bool block);
    
    // Switches the socket between blocking and non-blocking mode on Windows,
    // where MSG_DONTWAIT is unavailable; does nothing on other platforms
    void win32SetBlocking(bool block);
};

/******************** Stubs for all public members ********************/

ImageTransfer::ImageTransfer(OperationMode mode, const char* remoteAddress,
        const char* remoteService, const char* localAddress, const char* localService, int bufferSize):
        pimpl(new Pimpl(mode, remoteAddress, remoteService, localAddress, localService, bufferSize)) {
    // All initialization in the pimpl class
}

ImageTransfer::~ImageTransfer() {
    delete pimpl;
}

void ImageTransfer::setRawTransferData(const ImagePair& metaData, unsigned char* rawData,
        int secondTileWidth, int validBytes) {
    pimpl->setRawTransferData(metaData, rawData, secondTileWidth, validBytes);
}

void ImageTransfer::setRawValidBytes(int validBytes) {
    pimpl->setRawValidBytes(validBytes);
}

void ImageTransfer::setTransferImagePair(const ImagePair& imagePair) {
    pimpl->setTransferImagePair(imagePair);
}

ImageTransfer::TransferStatus ImageTransfer::transferData(bool block) {
    return pimpl->transferData(block);
}

bool ImageTransfer::receiveImagePair(ImagePair& imagePair, bool block) {
    return pimpl->receiveImagePair(imagePair, block);
}

bool ImageTransfer::receivePartialImagePair(ImagePair& imagePair, int& validRows, bool& complete, bool block) {
    return pimpl->receivePartialImagePair(imagePair, validRows, complete, block);
}

bool ImageTransfer::tryAccept() {
    return pimpl->tryAccept();
}

bool ImageTransfer::isClientConnected() const {
    return pimpl->isClientConnected();
}

void ImageTransfer::disconnect() {
    pimpl->disconnect();
}

std::string ImageTransfer::getClientAddress() const {
    return pimpl->getClientAddress();
}

/******************** Implementation in pimpl class *******************/

/*
 * Sets up the transfer object for the given operation mode.
 *
 * Initializes the socket layer (Winsock on Windows; elsewhere SIGPIPE is
 * ignored so that writing to a closed connection yields an error instead of
 * a signal), resolves the remote and local IPv4 addresses and opens the
 * socket(s) required by the mode. Null or empty addresses are replaced by
 * "0.0.0.0".
 *
 * Throws TransferException if any step fails. Since the destructor does not
 * run for an object whose constructor threw, everything acquired up to that
 * point (resolved addresses, sockets, the Winsock reference) is released here
 * before the exception propagates.
 */
ImageTransfer::Pimpl::Pimpl(OperationMode mode, const char* remoteAddress, const char* remoteService,
        const char* localAddress, const char* localService, int bufferSize)
        : mode(mode), socket(INVALID_SOCKET), serverSocket(INVALID_SOCKET), currentMsgLen(0),
        currentMsgOffset(0), currentMsg(nullptr), bufferSize(bufferSize), receptionFailed(false), socketIsBlocking(true) {

    // clientAddress is only filled in by tryAccept() and udpAddress only by
    // initUdp(); zero both so they never hold indeterminate values (e.g. when
    // getClientAddress() is called in TCP client mode)
    memset(&clientAddress, 0, sizeof(clientAddress));
    memset(&udpAddress, 0, sizeof(udpAddress));

#ifdef _WIN32
    // In windows, we first have to initialize winsock
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
        throw TransferException("WSAStartup failed!");
    }
#else
    // We don't want to be interrupted by the pipe signal
    signal(SIGPIPE, SIG_IGN);
#endif

    // If address is null or empty we use the any address
    if(remoteAddress == nullptr || *remoteAddress == '\0') {
        remoteAddress = "0.0.0.0";
    }
    if(localAddress == nullptr || *localAddress == '\0') {
        localAddress = "0.0.0.0";
    }

    // Resolve addresses: IPv4 only, stream sockets for TCP, datagrams for UDP
    addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET; // Use IPv4
    hints.ai_socktype = (mode == TCP_CLIENT || mode == TCP_SERVER) ? SOCK_STREAM : SOCK_DGRAM;
    hints.ai_flags = 0;
    hints.ai_protocol = 0;

    // getaddrinfo() reports failures through its return value, not errno
    auto resolverError = [](int code) -> string {
#ifdef _WIN32
        // On Windows the return value is a Winsock error code
        return string(strerror(code));
#else
#ifdef EAI_SYSTEM
        if(code == EAI_SYSTEM) {
            // The actual cause is a system error stored in errno
            return string(strerror(errno));
        }
#endif
        return string(gai_strerror(code));
#endif
    };

    addrinfo* remoteAddressInfo = nullptr;
    addrinfo* localAddressInfo = nullptr;

    try {
        int result = getaddrinfo(remoteAddress, remoteService, &hints, &remoteAddressInfo);
        if(result != 0 || remoteAddressInfo == nullptr) {
            // The output pointer is unspecified on failure; never free it
            remoteAddressInfo = nullptr;
            TransferException ex("Error resolving remote address: " + resolverError(result));
            throw ex;
        }

        result = getaddrinfo(localAddress, localService, &hints, &localAddressInfo);
        if(result != 0 || localAddressInfo == nullptr) {
            localAddressInfo = nullptr;
            TransferException ex("Error resolving local address: " + resolverError(result));
            throw ex;
        }

        // Perform initialization depending on the selected operation mode
        switch(mode) {
            case TCP_CLIENT: initTcpClient(*remoteAddressInfo, *localAddressInfo); break;
            case TCP_SERVER: initTcpServer(*localAddressInfo); break;
            case UDP: initUdp(*remoteAddressInfo, *localAddressInfo); break;
            default: throw TransferException("Illegal operation mode");
        }
    } catch(...) {
        // Release everything acquired so far; the destructor will not run
        if(remoteAddressInfo != nullptr) {
            freeaddrinfo(remoteAddressInfo);
        }
        if(localAddressInfo != nullptr) {
            freeaddrinfo(localAddressInfo);
        }
        if(socket != INVALID_SOCKET) {
            close(socket);
            socket = INVALID_SOCKET;
        }
        if(serverSocket != INVALID_SOCKET) {
            close(serverSocket);
            serverSocket = INVALID_SOCKET;
        }
#ifdef _WIN32
        // Balance the successful WSAStartup() above
        WSACleanup();
#endif
        throw;
    }

    freeaddrinfo(remoteAddressInfo);
    freeaddrinfo(localAddressInfo);
}

// Closes the data socket and the listening socket, if open. On Windows,
// also releases the Winsock reference that the constructor acquired with
// WSAStartup().
ImageTransfer::Pimpl::~Pimpl() {
    if(socket != INVALID_SOCKET) {
        close(socket);
    }

    if(serverSocket != INVALID_SOCKET) {
        close(serverSocket);
    }

#ifdef _WIN32
    // Every successful WSAStartup() must be matched by a WSACleanup()
    WSACleanup();
#endif
}

void ImageTransfer::Pimpl::initTcpClient(const addrinfo& remoteAddressInfo, const addrinfo& localAddressInfo) {
    protocol.reset(new ImageProtocol(ImageProtocol::PROTOCOL_TCP));

    // Connect
    socket = ::socket(remoteAddressInfo.ai_family, remoteAddressInfo.ai_socktype,
        remoteAddressInfo.ai_protocol);
    if(socket == INVALID_SOCKET) {
        TransferException ex("Error creating socket: " + string(strerror(errno)));
        throw ex;
    }

    // Enable reuse address
    int enable = 1;
    setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&enable), sizeof(int));

    // Bind to local port
    if (::bind(socket, localAddressInfo.ai_addr, static_cast<int>(localAddressInfo.ai_addrlen)) < 0)  {
        TransferException ex("Error binding socket: " + string(strerror(errno)));
        throw ex;
    }

    // Perform connection
    if(connect(socket, remoteAddressInfo.ai_addr, static_cast<int>(remoteAddressInfo.ai_addrlen)) < 0) {
        TransferException ex("Error connection to destination address: " + string(strerror(errno)));
        throw ex;
    }

    // Set special socket options
    setSocketOptions();
}

void ImageTransfer::Pimpl::initTcpServer(const addrinfo& localAddressInfo) {
    protocol.reset(new ImageProtocol(ImageProtocol::PROTOCOL_TCP));

    // Create socket
    serverSocket = ::socket(localAddressInfo.ai_family, localAddressInfo.ai_socktype,
        localAddressInfo.ai_protocol);
    if (serverSocket == INVALID_SOCKET)  {
        TransferException ex("Error opening socket: " + string(strerror(errno)));
        throw ex;
    }

    // Enable reuse address
    int enable = 1;
    setsockopt(serverSocket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&enable), sizeof(int));

    // Open a server port
    if (::bind(serverSocket, localAddressInfo.ai_addr, static_cast<int>(localAddressInfo.ai_addrlen)) < 0)  {
        TransferException ex("Error binding socket: " + string(strerror(errno)));
        throw ex;
    }

    // Make the server socket non-blocking
#ifdef _WIN32
    unsigned long on = 1;
    ioctlsocket(serverSocket, FIONBIO, &on);
#else
    fcntl(serverSocket, F_SETFL, O_NONBLOCK);
#endif

    // Listen on port
    listen(serverSocket, 1);
}

void ImageTransfer::Pimpl::initUdp(const addrinfo& remoteAddressInfo, const addrinfo& localAddressInfo) {
    protocol.reset(new ImageProtocol(ImageProtocol::PROTOCOL_UDP));
    // Create socket
    socket = ::socket(localAddressInfo.ai_family, localAddressInfo.ai_socktype,
        localAddressInfo.ai_protocol);
    if(socket == INVALID_SOCKET) {
        TransferException ex("Error creating socket: " + string(strerror(errno)));
        throw ex;
    }

    // Enable reuse address
    int enable = 1;
    setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&enable), sizeof(int));

    // Bind socket to port
    if (::bind(socket, localAddressInfo.ai_addr, static_cast<int>(localAddressInfo.ai_addrlen)) < 0)  {
        TransferException ex("Error binding socket: " + string(strerror(errno)));
        throw ex;
    }

    // Store remote address
    if(remoteAddressInfo.ai_addrlen != sizeof(udpAddress)) {
        throw TransferException("Illegal address length");
    }
    memcpy(&udpAddress, remoteAddressInfo.ai_addr, sizeof(udpAddress));

    // Set special socket options
    setSocketOptions();
}

bool ImageTransfer::Pimpl::tryAccept() {
    if(mode != TCP_SERVER) {
        throw TransferException("Connections can only be accepted in tcp server mode");
    }

    // Accept one connection
    socklen_t clientAddressLength = sizeof(clientAddress);

    SOCKET newSocket = accept(serverSocket,
        reinterpret_cast<sockaddr *>(&clientAddress),
        &clientAddressLength);

    if(newSocket == INVALID_SOCKET) {
        if(errno == EWOULDBLOCK || errno == ETIMEDOUT) {
            // No connection
            return false;
        } else {
            TransferException ex("Error accepting connection: " + string(strerror(errno)));
            throw ex;
        }
    }

    // Close old socket and set new socket
    if(socket != INVALID_SOCKET) {
        close(socket);
    }
    socket = newSocket;

    // Set special socket options
    setSocketOptions();

    // Reset connection data
    protocol->resetTransfer();
    protocol->resetReception();

    return true;
}

std::string ImageTransfer::Pimpl::getClientAddress() const {
    if(socket == INVALID_SOCKET) {
        return ""; // No client connected
    }

    char strPort[11];
    snprintf(strPort, sizeof(strPort), ":%d", clientAddress.sin_port);

    return string(inet_ntoa(clientAddress.sin_addr)) + strPort;
}

// Applies the options shared by all operation modes to the data socket:
// blocking mode (TCP server), socket buffer sizes, one-second send/receive
// timeouts, disabled multicast loopback and, where TCP_CONGESTION is
// available, a TCP congestion control algorithm. Results of the individual
// option calls are not checked.
void ImageTransfer::Pimpl::setSocketOptions() {
    // An accepted socket may have inherited the non-blocking flag of the
    // listening socket; switch it back to blocking mode
    if(mode == TCP_SERVER) {
#ifdef _WIN32
        unsigned long nonBlocking = 0;
        ioctlsocket(socket, FIONBIO, &nonBlocking);
#else
        fcntl(socket, F_SETFL, 0);
#endif
    }

    // Socket buffer sizes are only requested for positive values
    if(bufferSize > 0) {
        char* sizeValue = reinterpret_cast<char*>(&bufferSize);
        setsockopt(socket, SOL_SOCKET, SO_RCVBUF, sizeValue, sizeof(bufferSize));
        setsockopt(socket, SOL_SOCKET, SO_SNDBUF, sizeValue, sizeof(bufferSize));
    }

    // One-second timeout for both directions (milliseconds on Windows,
    // struct timeval elsewhere)
#ifdef _WIN32
    unsigned int timeout = 1000;
#else
    struct timeval timeout;
    timeout.tv_sec = 1;
    timeout.tv_usec = 0;
#endif
    char* timeoutValue = reinterpret_cast<char*>(&timeout);
    setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO, timeoutValue, sizeof(timeout));
    setsockopt(socket, SOL_SOCKET, SO_SNDTIMEO, timeoutValue, sizeof(timeout));

    // Disable multicast loopback
    unsigned char multicastLoop = 0;
    setsockopt(socket, IPPROTO_IP, IP_MULTICAST_LOOP,
        reinterpret_cast<char*>(&multicastLoop), sizeof(multicastLoop));

#ifdef TCP_CONGESTION
    // For TCP streams prefer the westwood congestion control algorithm and
    // fall back to reno if westwood cannot be set
    if(mode == TCP_SERVER || mode == TCP_CLIENT) {
        const char* preferred = "westwood";
        if(setsockopt(socket, IPPROTO_TCP, TCP_CONGESTION, preferred, strlen(preferred)) < 0) {
            const char* fallback = "reno";
            setsockopt(socket, IPPROTO_TCP, TCP_CONGESTION, fallback, strlen(fallback));
        }
    }
#endif
}

void ImageTransfer::Pimpl::setRawTransferData(const ImagePair& metaData,
        unsigned char* rawData, int secondTileWidth, int validBytes) {
    protocol->setRawTransferData(metaData, rawData, secondTileWidth, validBytes);
    currentMsg = nullptr;
}

void ImageTransfer::Pimpl::setRawValidBytes(int validBytes) {
    protocol->setRawValidBytes(validBytes);
}

void ImageTransfer::Pimpl::setTransferImagePair(const ImagePair& imagePair) {
    protocol->setTransferImagePair(imagePair);
    currentMsg = nullptr;
}

// On Windows, switches the socket between blocking and non-blocking mode via
// FIONBIO, because MSG_DONTWAIT is not available there. The mode is cached in
// socketIsBlocking so the ioctl is only issued when it actually changes.
// On other platforms this does nothing; non-blocking behavior is requested
// per call with MSG_DONTWAIT instead.
void ImageTransfer::Pimpl::win32SetBlocking(bool block) {
#ifdef _WIN32
    if(block == socketIsBlocking) {
        return; // Already in the requested mode
    }

    unsigned long nonBlocking = block ? 0 : 1;
    ioctlsocket(socket, FIONBIO, &nonBlocking);
    socketIsBlocking = block;
#else
    (void) block;
#endif
}

// Sends the pending transfer data to the remote side.
//
// In blocking mode, messages are sent until the protocol provides no further
// message, continuing partially written ones. In non-blocking mode at most
// one send call is made per invocation.
//
// Returns NO_VALID_DATA if there is nothing to send, WOULD_BLOCK if a
// non-blocking send could not proceed, CONNECTION_CLOSED if the TCP peer
// closed the connection (the socket is then closed), and otherwise
// ALL_TRANSFERRED or PARTIAL_TRANSFER depending on protocol->transferComplete().
// Throws TransferException on other send errors or truncated UDP datagrams.
ImageTransfer::TransferStatus ImageTransfer::Pimpl::transferData(bool block) {
    // Fetch a new message if none is outstanding
    if(currentMsg == nullptr) {
        currentMsgOffset = 0;
        currentMsg = protocol->getTransferMessage(currentMsgLen);

        if(currentMsg == nullptr) {
            return NO_VALID_DATA;
        }
    }

    while(currentMsg != nullptr) {
        int writing = currentMsgLen - currentMsgOffset;
        int written = 0;

        win32SetBlocking(block);

        switch(mode) {
            case TCP_SERVER:
            case TCP_CLIENT:
                written = send(socket, reinterpret_cast<const char*>(&currentMsg[currentMsgOffset]), writing,
                    block ? 0 : MSG_DONTWAIT);
                break;
            case UDP:
                written = sendto(socket, reinterpret_cast<const char*>(&currentMsg[currentMsgOffset]), writing,
                    block ? 0 : MSG_DONTWAIT, reinterpret_cast<sockaddr*>(&udpAddress), sizeof(udpAddress));
                break;
        }

        // Read the error code right away (errno maps to WSAGetLastError() on Windows)
        unsigned long sendError = errno;

        if(written < 0) {
            if(!block && (sendError == EAGAIN || sendError == EWOULDBLOCK || sendError == ETIMEDOUT)) {
                // The socket is not yet ready for a new transfer
                return WOULD_BLOCK;
            } else if((sendError == ECONNRESET || sendError == EPIPE)
                    && (mode == TCP_SERVER || mode == TCP_CLIENT)) {
                // Connection closed by remote host. Writing to a connection the
                // peer has already shut down fails with EPIPE (SIGPIPE is ignored
                // on POSIX), so it is treated the same as a reset.
                close(socket);
                socket = INVALID_SOCKET;
                // The remainder of this message can no longer be delivered
                currentMsg = nullptr;
                currentMsgOffset = 0;
                return CONNECTION_CLOSED;
            } else {
                TransferException ex("Error sending message: " + string(strerror(sendError)));
                throw ex;
            }
        } else if(written != writing) {
            // The message has been transmitted partially
            if(mode == UDP) {
                throw TransferException("Unable to transmit complete UDP message");
            }

            currentMsgOffset += written;
            if(!block) {
                return PARTIAL_TRANSFER;
            }
        } else if(!block) {
            // Only one iteration allowed in non-blocking mode
            currentMsg = nullptr;
            break;
        } else {
            // Get next message
            currentMsg = protocol->getTransferMessage(currentMsgLen);
            currentMsgOffset = 0;
        }
    }

    if(mode == TCP_SERVER && protocol->transferComplete()) {
        // Force a flush by turning the nagle algorithm off and on
        int flag = 1;
        setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char*>(&flag), sizeof(int));
        flag = 0;
        setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char*>(&flag), sizeof(int));
    }

    if(protocol->transferComplete()) {
        return ALL_TRANSFERRED;
    } else {
        return PARTIAL_TRANSFER;
    }
}

// Receives a complete image pair by repeatedly collecting partial data until
// the protocol reports the pair as complete. Returns false as soon as one
// partial reception attempt fails, and true once the full pair is available.
bool ImageTransfer::Pimpl::receiveImagePair(ImagePair& imagePair, bool block) {
    int validRows = 0;
    bool complete = false;

    do {
        if(!receivePartialImagePair(imagePair, validRows, complete, block)) {
            return false;
        }
    } while(!complete);

    return true;
}

// Reads available network data and returns the (possibly partial) image pair
// from the protocol. Only the first network read may block; subsequent reads
// are non-blocking. A reception error recorded earlier is reported by
// returning false exactly once.
bool ImageTransfer::Pimpl::receivePartialImagePair(ImagePair& imagePair,
        int& validRows, bool& complete, bool block) {
    if(receptionFailed) {
        receptionFailed = false;
        return false;
    }

    // Keep reading until the images are complete or no more data arrives
    bool blockOnRead = block;
    while(!protocol->imagesReceived()) {
        if(!receiveNetworkData(blockOnRead)) {
            break;
        }
        blockOnRead = false;
    }

    return protocol->getPartiallyReceivedImagePair(imagePair, validRows, complete);
}

// Performs one recv() into the protocol's next receive buffer and passes any
// received bytes to the protocol. In TCP client mode, a message rejected by
// the protocol sets receptionFailed. Returns the result of recv(): the number
// of bytes read, 0 if nothing was read (for TCP: the peer closed the
// connection), or a negative value on error.
int ImageTransfer::Pimpl::receiveSingleNetworkMessages(bool block) {
    int maxLength = 0;
    char* buffer = reinterpret_cast<char*>(protocol->getNextReceiveBuffer(maxLength));

#ifdef _WIN32
    // Blocking mode is controlled through win32SetBlocking() on Windows
    (void) block;
    const int flags = 0;
#else
    const int flags = block ? 0 : MSG_DONTWAIT;
#endif

    const int bytesReceived = recv(socket, buffer, maxLength, flags);

    const bool gotData = bytesReceived > 0;
    if(gotData && !protocol->processReceivedMessage(bytesReceived) && mode == TCP_CLIENT) {
        receptionFailed = true;
    }

    return bytesReceived;
}

bool ImageTransfer::Pimpl::receiveNetworkData(bool block) {
    win32SetBlocking(block);

    int received = receiveSingleNetworkMessages(block);
    bool ret = true;

    if(received == 0 && (mode == TCP_SERVER || mode == TCP_CLIENT)) {
        // Connection closed by remote host
        close(socket);
        socket = INVALID_SOCKET;
        ret = false;
    } else if(received < 0) {
        if(errno == EWOULDBLOCK || errno == EINTR || errno == ETIMEDOUT || errno == WSA_IO_PENDING) {
            // Reception was cancelled because it took too long,
            // or because of some signal. We reset reception for the current frame.
            ret = false;
        } else {
            TransferException ex("Error reading from socket: " + string(strerror(errno)));
            throw ex;
        }
    }

    return ret;
}

// Closes the current TCP data connection, if any. In server mode the
// listening socket is left open, so new clients can still be accepted.
// Throws TransferException when not in a TCP mode.
void ImageTransfer::Pimpl::disconnect() {
    const bool isTcp = (mode == TCP_SERVER || mode == TCP_CLIENT);
    if(!isTcp) {
        throw TransferException("Only TCP transfers can be disconnected");
    }

    if(socket == INVALID_SOCKET) {
        return; // Nothing to close
    }

    close(socket);
    socket = INVALID_SOCKET;
}

bool ImageTransfer::Pimpl::isClientConnected() const {
    return socket != INVALID_SOCKET;
}
