/*******************************************************************************
 * Copyright (c) 2019 Nerian Vision GmbH
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *******************************************************************************/

#include <iostream>
#include <iomanip>
#include <cstdio>
#include <functional>
#include <stdexcept>
#include <string>
#include <chrono>

#include "nvcom.h"

using namespace std;
using namespace std::chrono;
using namespace cv;
using namespace visiontransfer;

/**
 * Creates the object with all capture flags cleared, all counters, the cached
 * disparity range and the cached pixel formats set to zero, and then applies
 * the given settings through updateSettings().
 *
 * No worker thread is started here; that happens in connect().
 */
NVCom::NVCom(const Settings& newSettings)
    : terminateThreads(false),
      captureNextFrame(false),
      captureSequence(false),
      captureIndex(0),
      seqNum(0),
      minDisparity(0),
      maxDisparity(0),
      lastLeftFormat(0),
      lastRightFormat(0) {

    updateSettings(newSettings);
}

// Shuts the object down by delegating to terminate(), which joins the worker
// threads and releases the image reader and the transfer object.
NVCom::~NVCom() {
    terminate();
}

// Stops both worker threads and afterwards releases the image reader and the
// network transfer object, in that order.
void NVCom::terminate() {
    // The threads use the objects released below, so they are joined first
    joinAllThreads();

    imageReader.reset();
    asyncTrans.reset();
}

void NVCom::joinAllThreads() {
    terminateThreads = true;

    if(writingThread.joinable()) {
        writingCond.notify_all();
        writingThread.join();
    }

    if(mainLoopThread.joinable()) {
        mainLoopThread.join();
    }
}

void NVCom::connect() {
    // Just need to start the threads
    terminateThreads = false;
    mainLoopThread = std::thread(std::bind(&NVCom::mainLoop, this));
    writingThread = std::thread(std::bind(&NVCom::writeLoop, this));
}

/**
 * Body of the main worker thread.
 *
 * Until termination is requested, each iteration creates the network transfer
 * object if it does not exist yet, reports a lost connection, transmits the
 * next input frame (if there is one) and - unless reception is disabled -
 * receives, displays and optionally captures one image pair.
 *
 * A std::exception thrown inside the loop ends the loop and is handed to
 * exceptionCallback().
 */
void NVCom::mainLoop() {
    cv::Mat receivedLeftFrame, receivedRightFrame;
    // Number of iterations in which transmitFrame() reported that no frame is left
    int endTransmissionCtr = 0;
    // Connection state observed in the previous iteration
    bool connected = false;

    try {
        while(!terminateThreads) {
            // First establish network connection
            if(asyncTrans == nullptr) {
                asyncTrans.reset(new AsyncTransfer(
                    settings.remoteHost.c_str(), to_string(settings.remotePort).c_str(),
                    settings.tcp ? ImageProtocol::PROTOCOL_TCP : ImageProtocol::PROTOCOL_UDP));
                connectedCallback();
            }

            // Detect disconnects: report only the transition from connected to not connected
            if(connected && !asyncTrans->isConnected()) {
                disconnectCallback();
            }
            connected = asyncTrans->isConnected();

            if(!transmitFrame() && endTransmissionCtr < 10) {
                // All input frames have been sent. Keep receiving for a few more
                // iterations; clearing readImages also switches off the rate
                // limiting at the end of the loop.
                settings.readImages = false;
                endTransmissionCtr++;
                if(endTransmissionCtr == 10) {
                    sendCompleteCallback();
                }
            }

            if(!settings.disableReception) {
                ImagePair imagePair;
                receiveFrame(imagePair, receivedLeftFrame, receivedRightFrame);
                if(receivedLeftFrame.data == nullptr || receivedRightFrame.data == nullptr) {
                    // Nothing usable was received; this also skips the rate limiting below
                    continue;
                }

                int minDisp = 0, maxDisp = 0;
                imagePair.getDisparityRange(minDisp, maxDisp);

                if(minDisp != minDisparity || maxDisp != maxDisparity) {
                    minDisparity = minDisp;
                    maxDisparity = maxDisp;

                    // Force update of legend: convertFrame() creates a new legend
                    // border only when the destination image is empty
                    convertedLeftFrame = cv::Mat();
                    convertedRightFrame = cv::Mat();
                    // NOTE(review): the factor 16 presumably matches the 1/16 scale
                    // that convertFrame() applies to 16-bit data - confirm
                    redBlueCoder.reset(new ColorCoder(ColorCoder::COLOR_RED_BLUE_RGB, minDisp*16,
                        maxDisp*16, true, true));
                    rainbowCoder.reset(new ColorCoder(ColorCoder::COLOR_RAINBOW_RGB, minDisp*16,
                        maxDisp*16, true, true));
                }

                colorCodeAndDisplay(receivedLeftFrame, receivedRightFrame, imagePair.isImageDisparityPair());
                captureFrameIfRequested(imagePair, receivedLeftFrame, receivedRightFrame);

                if(settings.displayCoordinate) {
                    // Keep a copy of the image pair for getDisparityMapPoint(),
                    // which reads lastImagePair under the same mutex
                    unique_lock<mutex> lock(imagePairMutex);
                    imagePair.copyTo(lastImagePair);
                }
            }

            // Sleep for a while if we are processing too fast
            if(settings.readImages) {
                frameRateLimit->next();
            }
        }
    } catch(const std::exception& ex) {
        exceptionCallback(ex);
    }
}

/**
 * Converts a received frame pair to RGB and passes it to the display callback.
 *
 * @param receivedLeftFrame   Left image as received.
 * @param receivedRightFrame  Right image as received. Nothing is displayed
 *                            if this image has no data.
 * @param imageDisparityPair  If true, the right image is color coded by
 *                            convertFrame() according to the color scheme.
 */
void NVCom::colorCodeAndDisplay(const cv::Mat& receivedLeftFrame, const cv::Mat& receivedRightFrame,
        bool imageDisparityPair) {
    if(receivedRightFrame.data == nullptr) {
        // Can't display anything
        return;
    }

    // A change of the frame size or of either pixel format requires new
    // conversion buffers and is reported to the display callback
    const bool resize = lastFrameSize != receivedLeftFrame.size() ||
        lastLeftFormat != receivedLeftFrame.type() ||
        lastRightFormat != receivedRightFrame.type();

    if(resize) {
        convertedLeftFrame = cv::Mat();
        convertedRightFrame = cv::Mat();

        lastFrameSize = receivedLeftFrame.size();
        lastLeftFormat = receivedLeftFrame.type();
        lastRightFormat = receivedRightFrame.type();
    }

    // Conversion and display both happen while the display mutex is held
    unique_lock<mutex> lock(displayMutex);
    convertFrame(receivedLeftFrame, convertedLeftFrame, false);
    convertFrame(receivedRightFrame, convertedRightFrame, imageDisparityPair);
    frameDisplayCallback(convertedLeftFrame, convertedRightFrame, resize);
}

bool NVCom::transmitFrame() {
    std::shared_ptr<ImageReader::StereoFrame> stereoSendFrame;

    // Get frame to send if an image queue exists
    if(imageReader != nullptr) {
        stereoSendFrame = imageReader->pop();
        if(stereoSendFrame == nullptr) {
            // No more frames
            return false;
        }
    }

    // Transmit frame
    if(stereoSendFrame != nullptr) {
        ImagePair pair;
        pair.setWidth(stereoSendFrame->first.cols);
        pair.setHeight(stereoSendFrame->first.rows);
        pair.setRowStride(0, stereoSendFrame->first.step[0]);
        pair.setRowStride(1, stereoSendFrame->second.step[0]);
        // TODO: support for RGB images
        pair.setPixelFormat(0, stereoSendFrame->first.type() == CV_8U ?
            ImagePair::FORMAT_8_BIT_MONO : ImagePair::FORMAT_12_BIT_MONO);
        pair.setPixelFormat(1, stereoSendFrame->second.type() == CV_8U ?
            ImagePair::FORMAT_8_BIT_MONO : ImagePair::FORMAT_12_BIT_MONO);
        pair.setSequenceNumber(seqNum++);

        steady_clock::time_point time = steady_clock::now();
        long long microSecs = duration_cast<microseconds>(time.time_since_epoch()).count();
        pair.setTimestamp(microSecs / 1000000, microSecs % 1000000);

        // Clone image data such that we can delete the original
        unsigned char* leftPixel = new unsigned char[stereoSendFrame->first.step[0] * stereoSendFrame->first.rows];
        unsigned char* rightPixel = new unsigned char[stereoSendFrame->second.step[0] * stereoSendFrame->second.rows];
        memcpy(leftPixel, stereoSendFrame->first.data, stereoSendFrame->first.step[0] * stereoSendFrame->first.rows);
        memcpy(rightPixel, stereoSendFrame->second.data, stereoSendFrame->second.step[0] * stereoSendFrame->second.rows);

        pair.setPixelData(0, leftPixel);
        pair.setPixelData(1, rightPixel);

        asyncTrans->sendImagePairAsync(pair, true);
    }

    return true;
}

/**
 * Tries to collect one received image pair and converts it to OpenCV images.
 *
 * @param imagePair           Receives the collected image pair. Reset to a
 *                            default constructed pair if nothing was received.
 * @param receivedLeftFrame   Receives image 0 of the pair, or an empty matrix.
 * @param receivedRightFrame  Receives image 1 of the pair, or an empty matrix.
 */
void NVCom::receiveFrame(ImagePair& imagePair, cv::Mat& receivedLeftFrame, cv::Mat& receivedRightFrame) {
    if(!asyncTrans->collectReceivedImagePair(imagePair, 0.1)) {
        // No image received yet
        imagePair = ImagePair();
        receivedLeftFrame = cv::Mat();
        receivedRightFrame = cv::Mat();
        return;
    }

    if(settings.printTimestamps) {
        int secs = 0, microsecs = 0;
        imagePair.getTimestamp(secs, microsecs);

        // Print seconds and microseconds as integers. Streaming their sum as a
        // floating point value keeps only six significant digits by default
        // and would print a large timestamp as e.g. "1.55e+09". The format
        // matches the one used for the timestamp file.
        char stamp[32];
        snprintf(stamp, sizeof(stamp), "%d.%06d", secs, microsecs);
        cout << stamp << endl;
    }

    // Convert received data to opencv images
    imagePair.toOpenCVImage(0, receivedLeftFrame, false);
    imagePair.toOpenCVImage(1, receivedRightFrame, false);
}

/**
 * Converts a received image to an RGB image for display.
 *
 * @param src        Source image.
 * @param dst        Destination image. For color coded output an empty
 *                   destination is first replaced by a legend border image.
 * @param colorCode  If true, 16-bit input is color coded unless the color
 *                   scheme is COLOR_SCHEME_NONE.
 */
void NVCom::convertFrame(const cv::Mat& src, cv::Mat_<cv::Vec3b>& dst, bool colorCode) {
    if(src.type() != CV_16U) {
        // Just convert grey to 8 bit RGB; other images are assigned as they are
        if(src.channels() == 1) {
            cvtColor(src, dst, cv::COLOR_GRAY2RGB);
        } else {
            dst = src;
        }
        return;
    }

    if(!colorCode || settings.colorScheme == Settings::COLOR_SCHEME_NONE) {
        // Convert 16 to 8 bit
        cv::Mat_<unsigned char> reduced;
        src.convertTo(reduced, CV_8U, 1.0/16.0);
        cvtColor(reduced, dst, cv::COLOR_GRAY2RGB);
        return;
    }

    ColorCoder* coder = (settings.colorScheme == Settings::COLOR_SCHEME_RED_BLUE ?
        redBlueCoder : rainbowCoder).get();
    if(coder == nullptr) {
        // No color coder available; dst is left untouched
        return;
    }

    // Perform color coding
    if(dst.data == nullptr) {
        dst = coder->createLegendBorder(src.cols, src.rows, 1.0/16.0);
    }

    // Only the top left region of the source size is written
    cv::Mat_<cv::Vec3b> imageRegion = dst(Rect(0, 0, src.cols, src.rows));
    coder->codeImage(static_cast<cv::Mat_<unsigned short>>(src), imageRegion);
}

/**
 * Writes the given frame pair to the output directory if capturing of the
 * next frame or of a sequence has been requested.
 *
 * Images are queued for the writing thread. The timestamp and, if enabled,
 * the point cloud are written directly. The capture index is incremented
 * for every captured frame, even if the right image has no data.
 *
 * @param imagePair           Image pair providing the timestamp and the data
 *                            for the point cloud.
 * @param receivedLeftFrame   Left image as received.
 * @param receivedRightFrame  Right image as received.
 * @throws std::runtime_error if the timestamp file cannot be created.
 */
void NVCom::captureFrameIfRequested(const ImagePair& imagePair, const cv::Mat& receivedLeftFrame,
        const cv::Mat& receivedRightFrame) {

    if(!captureNextFrame && !captureSequence) {
        return;
    }

    if(!captureSequence) {
        cout << "Writing frame " << captureIndex << endl;
    }

    captureNextFrame = false;

    if(receivedRightFrame.data != nullptr) {
        // 8-bit images, and 16-bit images if raw writing is enabled, are written
        // as received. Everything else is written from the converted display image.
        auto scheduleCamera = [this](const cv::Mat& received, const cv::Mat& converted, int camera) {
            if(received.type() == CV_8U || (received.type() == CV_16U && settings.writeRaw16Bit)) {
                scheduleWrite(received.clone(), captureIndex, camera);
            } else {
                cv::Mat_<cv::Vec3b> bgrImage;
                cvtColor(converted, bgrImage, cv::COLOR_RGB2BGR);
                scheduleWrite(bgrImage, captureIndex, camera);
            }
        };

        scheduleCamera(receivedLeftFrame, convertedLeftFrame, 0);
        scheduleCamera(receivedRightFrame, convertedRightFrame, 1);

        // Write timestamps
        int secs = 0, microsecs = 0;
        imagePair.getTimestamp(secs, microsecs);
        if(!timestampFile.is_open())  {
            timestampFile.open(settings.writeDir + "/timestamps.txt", ios::out);
            if(timestampFile.fail()) {
                throw std::runtime_error("Unable to create timestamp file!");
            }
            // Microseconds are zero padded to six digits
            timestampFile.fill('0');
        }
        timestampFile << captureIndex << "; " << secs << "." << setw(6) << microsecs << endl;

        // For simplicity, point clouds are not written asynchronously
        if(settings.writePointCloud) {
            // Large enough for any int index. A 19 byte buffer would cut off the
            // extension once the index needs more than six digits.
            char fileName[32];
            snprintf(fileName, sizeof(fileName), "image%06d_3d.ply", captureIndex);
            recon3d.writePlyFile((settings.writeDir + "/" + fileName).c_str(), imagePair, settings.pointCloudMaxDist,
                settings.binaryPointCloud);
        }
    }

    captureIndex++;
}

void NVCom::writeLoop() {
    try {
        while(!terminateThreads) {
            pair<string, Mat> frame;
            {
                unique_lock<mutex> lock(writingMutex);
                while(writeQueue.size() == 0 && !terminateThreads) {
                    writingCond.wait(lock);
                }
                if(terminateThreads) {
                    return;
                }
                frame = writeQueue.front();
                writeQueue.pop();
            }
            if(frame.second.data != NULL && frame.second.cols != 0) {
                string filePath = settings.writeDir + "/" + string(frame.first);
                if(!imwrite(filePath, frame.second))
                    cerr << "Error writing file: " << filePath << endl;
            }
        }
    } catch(const std::exception& ex) {
        exceptionCallback(ex);
    }
}

/**
 * Queues an image for the writing thread.
 *
 * @param frame   Image to write.
 * @param index   Frame index used in the file name.
 * @param camera  Camera number used in the file name.
 */
void NVCom::scheduleWrite(const cv::Mat& frame, int index, int camera) {
    // Large enough for any pair of int values. With a 19 byte buffer the ".png"
    // extension is cut off once the index needs more than six digits, and
    // writing the image then fails.
    char fileName[40];
    snprintf(fileName, sizeof(fileName), "image%06d_c%d.png", index, camera);

    unique_lock<mutex> lock(writingMutex);
    writeQueue.push(pair<string, Mat>(fileName, frame));
    writingCond.notify_one();
}

/**
 * Replaces the current settings. Running worker threads are stopped while the
 * settings are exchanged and restarted afterwards.
 *
 * @param newSettings  The settings to apply.
 */
void NVCom::updateSettings(const Settings& newSettings) {
    // Remember whether threads were running before they are stopped
    const bool restartThreads = mainLoopThread.joinable() || writingThread.joinable();
    if(restartThreads) {
        joinAllThreads();
    }

    const bool colorSchemeChanged = newSettings.colorScheme != settings.colorScheme;
    if(colorSchemeChanged) {
        unique_lock<mutex> lock(displayMutex);

        // Make sure that a new color legend is created
        convertedLeftFrame = cv::Mat();
        convertedRightFrame = cv::Mat();
        lastFrameSize = cv::Size2i(0,0);
    }

    const bool writeDirChanged = newSettings.writeDir != settings.writeDir;
    if(writeDirChanged) {
        // A new output directory starts with its own timestamp file and index 0
        timestampFile.close();
        captureIndex = 0;
    }

    const bool inputChanged = newSettings.readImages != settings.readImages ||
        newSettings.readDir != settings.readDir;
    if(inputChanged) {
        const bool haveInput = newSettings.readDir != "" && newSettings.readImages;
        if(haveInput) {
            imageReader.reset(new ImageReader(newSettings.readDir.c_str(), 10, true));
        } else {
            imageReader.reset();
        }
    }

    // The rate limiter is always recreated from the new frame rate
    frameRateLimit.reset(new RateLimit(newSettings.maxFrameRate));

    settings = newSettings;

    if(restartThreads) {
        connect();
    }
}

/**
 * Reports the bit depths that correspond to the pixel formats of the last
 * displayed frame pair.
 *
 * @param bitsLeft   Receives 8, 12 or 24 for the left image, or -1 if the
 *                   format is none of CV_8U, CV_16U and CV_8UC3.
 * @param bitsRight  Same for the right image.
 */
void NVCom::getBitDepths(int& bitsLeft, int& bitsRight) {
    auto bitsForFormat = [](int format) -> int {
        if(format == CV_8U) {
            return 8;
        } else if(format == CV_16U) {
            // 16-bit matrices are reported as 12 bits
            return 12;
        } else if(format == CV_8UC3) {
            return 24;
        }
        return -1;
    };

    bitsLeft = bitsForFormat(lastLeftFormat);
    bitsRight = bitsForFormat(lastRightFormat);
}

/**
 * Projects one pixel of the most recently stored disparity map to a 3D point.
 *
 * @param x  Column of the pixel.
 * @param y  Row of the pixel.
 * @return The projected point, or (0, 0, 0) if coordinate display is disabled,
 *         the stored pair is not a 12-bit image / disparity pair, the
 *         coordinate lies outside of the image or the disparity is 0xFFF.
 */
cv::Point3f NVCom::getDisparityMapPoint(int x, int y) {
    unique_lock<mutex> lock(imagePairMutex);

    const bool validDisparityMap = settings.displayCoordinate &&
        lastImagePair.isImageDisparityPair() &&
        lastImagePair.getPixelFormat(1) == ImagePair::FORMAT_12_BIT_MONO;
    if(!validDisparityMap) {
        // This is not a valid disparity map
        return cv::Point3f(0, 0, 0);
    }

    const bool insideImage = x >= 0 && y >= 0 &&
        x < lastImagePair.getWidth() && y < lastImagePair.getHeight();
    if(!insideImage) {
        return cv::Point3f(0, 0, 0);
    }

    // The row stride is divided by two because it is used to index 16-bit values
    const unsigned short* disparityData = reinterpret_cast<const unsigned short*>(
        lastImagePair.getPixelData(1));
    const unsigned short disp = disparityData[y*lastImagePair.getRowStride(1)/2 + x];

    if(disp == 0xFFF) {
        // Invalid
        return cv::Point3f(0, 0, 0);
    }

    // Function-local instance, created the first time a valid pixel is projected
    static Reconstruct3D recon3d;

    float x3d, y3d, z3d;
    recon3d.projectSinglePoint(x, y, disp, lastImagePair.getQMatrix(),
        x3d, y3d, z3d);
    return cv::Point3f(x3d, y3d, z3d);
}
